"""Counts the number of examples that have a sequence length greater than some threshold(s)."""
import json
import pydoc
from typing import Optional

from absl import app
from absl import flags

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from npeff_torch.util import tokenizer_utils

###############################################################################

FLAGS = flags.FLAGS

# Identifier handed to tokenizer_utils.from_pretrained.
flags.DEFINE_string('tokenizer', None, '')

# Comma-separated list of sequence-length thresholds; each entry is parsed as an int.
flags.DEFINE_list('threshold', [], '')
# Forwarded to the dataset loader as `sequence_length`; every threshold must be below it.
flags.DEFINE_integer('max_sequence_length', None, '')


# Forwarded as-is to the dataset loader function.
flags.DEFINE_string('task', None, '')
flags.DEFINE_string('subtask', None, '')
flags.DEFINE_string('split', None, '')

flags.DEFINE_string('load_dataset_fn_path', None, 'Should return a HuggingFace dataset ready to be passed to a dataloader.')
# If this is provided, it should be a JSON dict mapping parameter names to their values. Currently, only
# JSON-encodable values can be provided.
flags.DEFINE_string('load_dataset_fn_kwargs', None, '')

# Declared as an integer (not a string) because the value is used as an example count for `ds.take`.
flags.DEFINE_integer('n_examples', None, 'Leave None to do the whole dataset.')


###############################################################################


def _read_flag_kwargs(flag_value: Optional[str]):
    if flag_value:
        return json.loads(flag_value)
    else:
        return {}


def _read_thresholds():
    """Return the entries of the `threshold` flag converted to ints, in flag order."""
    return list(map(int, FLAGS.threshold))

        
###############################################################################


@torch.no_grad()
def main(_):
    """Count how many dataset examples are longer than each threshold.

    Loads the tokenizer and dataset configured through the command-line
    flags, iterates the dataset one example at a time, takes each example's
    length to be the number of non-zero entries in its attention mask, and
    prints, per threshold, how many examples are strictly longer than it,
    followed by the total number of examples seen.

    Args:
        _: Positional command-line arguments supplied by `app.run`; unused.

    Raises:
        ValueError: If no threshold is given, `max_sequence_length` is unset,
            a threshold is not below `max_sequence_length`, the dataset
            loader cannot be located, or an attention mask is not 1-D after
            removing the batch dimension.
    """
    # Explicit raises instead of asserts so validation survives `python -O`.
    if not FLAGS.threshold:
        raise ValueError('At least one --threshold must be provided.')
    if FLAGS.max_sequence_length is None:
        raise ValueError('--max_sequence_length must be provided.')

    # De-duplicate: `counts` is keyed by threshold, so a repeated threshold
    # would otherwise be incremented once per repetition for every example.
    thresholds = sorted(set(_read_thresholds()))
    too_large = [t for t in thresholds if t >= FLAGS.max_sequence_length]
    if too_large:
        raise ValueError(
            f'Thresholds must be below max_sequence_length={FLAGS.max_sequence_length}; got {too_large}.'
        )

    tokenizer = tokenizer_utils.from_pretrained(FLAGS.tokenizer)

    if not FLAGS.load_dataset_fn_path:
        raise ValueError('--load_dataset_fn_path must be provided.')
    # pydoc.locate returns None when the dotted path cannot be resolved.
    load_dataset_fn = pydoc.locate(FLAGS.load_dataset_fn_path)
    if load_dataset_fn is None:
        raise ValueError(f'Could not locate dataset loader: {FLAGS.load_dataset_fn_path!r}')

    ds = load_dataset_fn(
        task=FLAGS.task,
        subtask=FLAGS.subtask,
        split=FLAGS.split,
        tokenizer=tokenizer,
        sequence_length=FLAGS.max_sequence_length,
        **_read_flag_kwargs(FLAGS.load_dataset_fn_kwargs)
    )

    if FLAGS.n_examples is not None:
        # int() makes this work whether the flag value arrives as a string or
        # as an integer; `take` needs an integer count.
        ds = ds.take(int(FLAGS.n_examples))

    # batch_size=1 so every batch holds exactly one example.
    dataloader = DataLoader(ds, batch_size=1)

    total_examples_count = 0
    counts = {t: 0 for t in thresholds}

    for example in tqdm(dataloader):
        total_examples_count += 1

        # Drop the size-1 batch dimension added by the dataloader.
        attention_mask = torch.squeeze(example['attention_mask'], dim=0)
        if attention_mask.dim() != 1:
            raise ValueError(
                f'Expected a 1-D attention mask per example, got shape {tuple(attention_mask.shape)}.'
            )

        # Number of non-zero mask entries, as a plain Python int.
        seqlen = int((attention_mask != 0).sum().item())

        for t in thresholds:
            if seqlen > t:
                counts[t] += 1

    for t in thresholds:
        print(f'{t}: {counts[t]}')

    print(f'n_total_examples: {total_examples_count}')


# Script entry point: let absl parse the command-line flags, then invoke main.
if __name__ == "__main__":
    app.run(main)
