R"""Computes the max sequence length needed to encode a task given a tokenizer.


Example Usage:

CUDA_VISIBLE_DEVICES= python scripts1/misc/compute_max_sequence_length_for_dataset.py \
    --tokenizer=bert-base-uncased \
    --task=winogrande/xl \
    --splits=train,validation

"""
import os

from absl import app
from absl import flags

import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer

from em import datasets as em_datasets
from em.util.color_util import cu

# Command-line flags for this script; values are read through FLAGS after
# absl has parsed argv.
FLAGS = flags.FLAGS

flags.DEFINE_string(
    "tokenizer",
    None,
    "Name or local path of the tokenizer, loaded with AutoTokenizer.from_pretrained. "
    "A leading '~' is expanded.")

flags.DEFINE_string("task", None, "Name of the task to load, e.g. winogrande/xl.")
flags.DEFINE_list("splits", ['train'], "Comma-separated list of dataset splits to process.")

flags.DEFINE_integer("n_examples", None, "Leave as None to use all examples.")

# Examples are encoded at this length; the real max length is recovered by
# counting how much of each encoded row is padding.
flags.DEFINE_integer("trial_sequence_length",
                     512,
                     'Must be greater than the actual max sequence length. Larger values will be slower though.')

flags.DEFINE_integer("batch_size", 2048, "Will not affect result but might affect speed.")
flags.DEFINE_integer("prefetch_size", 4, "Will not affect result but might affect speed.")


def get_extra_task_kwargs():
    """Returns task-specific keyword arguments for the dataset loader.

    Tasks whose name (the --task flag) starts with 'winogrande' get
    `force_deterministic=True`; every other task gets no extra arguments.
    """
    extra_kwargs = {}
    if FLAGS.task.startswith('winogrande'):
        extra_kwargs['force_deterministic'] = True
    return extra_kwargs


def get_max_sequence_length(tokenizer, batch: np.ndarray):
    """Returns the longest unpadded sequence length found in a batch.

    The length of a row is taken to be the full row width minus the number of
    entries equal to the tokenizer's pad token id. Every occurrence of the pad
    id is counted, wherever it appears in the row.

    Args:
        tokenizer: Object exposing a `pad_token_id` attribute.
        batch: Array of token ids whose last axis is the sequence axis.

    Returns:
        The full sequence length minus the smallest per-row pad count, or 0
        if the batch contains no rows.

    Raises:
        ValueError: If the tokenizer has no pad token id, or if some row
            contains no padding at all (so it may have been truncated).
    """
    pad_token = tokenizer.pad_token_id
    if pad_token is None:
        # Comparing against None matches nothing, which would otherwise be
        # misreported as "no padding, increase --trial_sequence_length".
        raise ValueError(
            'The tokenizer does not define a pad token id, so padding cannot be measured.')

    full_sequence_length = batch.shape[-1]
    pad_counts = (batch == pad_token).sum(axis=-1)

    if pad_counts.size == 0:
        # No rows at all: nothing to measure, and min() would fail on an
        # empty array.
        return 0

    min_padding = pad_counts.min()

    if min_padding == 0:
        raise ValueError(
            f'An example used at least {full_sequence_length} tokens. Please increase the --trial_sequence_length flag.')

    return full_sequence_length - min_padding


def _get_input_ids_fn(x, y):
    return x['input_ids']


def main(_):
    """Prints the max sequence length for each requested split and overall.

    For every split in --splits the task is loaded with the tokenizer at
    --trial_sequence_length, only the `input_ids` are kept, and the longest
    unpadded sequence across all batches is reported.

    Raises:
        app.UsageError: If --splits is empty.
        ValueError: Propagated from `get_max_sequence_length`, e.g. when an
            example fills the whole trial sequence length.
    """
    if not FLAGS.splits:
        # Fail before loading anything; otherwise the final max() over an
        # empty list would raise only after all the work was done.
        raise app.UsageError('At least one split must be provided via --splits.')

    tokenizer = AutoTokenizer.from_pretrained(os.path.expanduser(FLAGS.tokenizer))

    maxlens = []

    for split in FLAGS.splits:
        ds = em_datasets.load(
            FLAGS.task,
            split=split,
            tokenizer=tokenizer,
            sequence_length=FLAGS.trial_sequence_length,
            **get_extra_task_kwargs(),
        )
        if FLAGS.n_examples is not None:
            # Honor --n_examples (previously declared but ignored). Applied
            # before batching so the limit counts examples, not batches.
            # NOTE(review): assumes `ds` is a tf.data.Dataset (it already
            # relies on map/batch/prefetch/as_numpy_iterator) - confirm.
            ds = ds.take(FLAGS.n_examples)
        ds = ds.map(_get_input_ids_fn).batch(FLAGS.batch_size).prefetch(FLAGS.prefetch_size)

        maxlen = 0

        for batch in tqdm(ds.as_numpy_iterator()):
            maxlen = max(maxlen, get_max_sequence_length(tokenizer, batch))

        maxlens.append(maxlen)

        print(f'Max sequence length for split {cu.lc(split)}: {cu.hly(maxlen)}')

    print('Summary:')
    for maxlen, split in zip(maxlens, FLAGS.splits):
        print(f'    Max sequence length for split {cu.lc(split)}: {cu.hly(maxlen)}')
    print(f'Overall maximum sequence length: {cu.hly(max(maxlens))}')


if __name__ == "__main__":
    # Both flags default to None and have no usable fallback; requiring them
    # makes absl report a clear usage error instead of a later TypeError /
    # AttributeError inside main.
    flags.mark_flags_as_required(["tokenizer", "task"])
    app.run(main)
