"""Custom T5 Tasks for fill-in-the-blank.
"""

import functools
import os

import t5
from t5.data import preprocessors as t5_preprocessors
from t5.evaluation import metrics as t5_metrics

import utils

# Module-level shorthands for the T5 task class and the two registries that
# the task and mixture definitions below add entries to.
TfdsTask = t5.data.TfdsTask
MixtureRegistry = t5.data.MixtureRegistry
TaskRegistry = t5.data.TaskRegistry


################################ TRAIN TASKS ###################################

# NOTE(review): `custom_preprocessors` is referenced below but no import for it
# is visible in this module -- confirm it is brought into scope, otherwise
# registering the tasks raises NameError when this module is imported.
def _register_c4_fill_in_task(name, max_words_in_sample, ensure_blank_at_end):
    """Adds one C4 fill-in `TfdsTask` to `TaskRegistry` under `name`.

    Every task registered through this helper shares the same dataset
    ("c4/en:3.0.1"), the same random text splitting (256 to 512 words per
    segment), the same blank size bins, and the same postprocessor and
    metrics. The tasks differ only in the two blanking options below.

    Args:
      name: Name under which the task is registered.
      max_words_in_sample: Forwarded unchanged to
        `custom_preprocessors.fitb_sized_maybe_with_sample`. The tasks below
        use 0 for "no conditioning" and 10 for conditioning on a
        sub-sequence of at most 10 words.
      ensure_blank_at_end: Forwarded unchanged to
        `custom_preprocessors.fitb_sized_maybe_with_sample`. The tasks below
        use True for the fill-in-the-end variants and False otherwise.
    """
    # The lists and partials are rebuilt on every call so that each task owns
    # its own preprocessor and metric objects, as with separate literal calls.
    split_into_segments = functools.partial(
        t5_preprocessors.random_split_text,
        min_words_per_segment=256,
        max_words_per_segment=512,
        text_key="text")
    blank_out_span = functools.partial(
        custom_preprocessors.fitb_sized_maybe_with_sample,
        size_bins=[1, 2, 4, 8, 16, 32, 64],
        max_words_in_sample=max_words_in_sample,
        ensure_blank_at_end=ensure_blank_at_end,
        text_key="text")
    TaskRegistry.add(
        name,
        TfdsTask,
        tfds_name="c4/en:3.0.1",
        text_preprocessor=[split_into_segments, blank_out_span],
        postprocess_fn=utils.add_example_postprocess_fn,
        metric_fns=[utils.wrap_metric(t5_metrics.accuracy),
                    utils.wrap_metric(t5_metrics.bleu)])


# Fill-in-the-blank task without any conditioning on words from the blank.
_register_c4_fill_in_task(
    "c4_fitb_cond0", max_words_in_sample=0, ensure_blank_at_end=False)

# Fill-in-the-blank task conditioning on sub-seq of max 10 words from the blank.
_register_c4_fill_in_task(
    "c4_fitb_condmax10", max_words_in_sample=10, ensure_blank_at_end=False)

# Fill-in-the-end task without any conditioning on words from the end.
_register_c4_fill_in_task(
    "c4_fite_cond0", max_words_in_sample=0, ensure_blank_at_end=True)

# Fill-in-the-end task with conditioning on sub-seq of max 10 words from end.
_register_c4_fill_in_task(
    "c4_fite_condmax10", max_words_in_sample=10, ensure_blank_at_end=True)


################################# MIXTURES #####################################

# Mixture over all four tasks: both fill-in-the-blank variants followed by
# both fill-in-the-end variants, registered with default_rate=1.0.
MixtureRegistry.add(
    "c4_fitb_plus_fite",
    [
        "c4_fitb_condmax10",
        "c4_fitb_cond0",
        "c4_fite_condmax10",
        "c4_fite_cond0",
    ],
    default_rate=1.0)

# Mixture over the two fill-in-the-blank tasks only (with and without
# conditioning), registered with default_rate=1.0.
MixtureRegistry.add(
    "c4_fitb",
    [
        "c4_fitb_condmax10",
        "c4_fitb_cond0",
    ],
    default_rate=1.0)

# Mixture over the two fill-in-the-end tasks only (with and without
# conditioning), registered with default_rate=1.0.
MixtureRegistry.add(
    "c4_fite",
    [
        "c4_fite_condmax10",
        "c4_fite_cond0",
    ],
    default_rate=1.0)
