# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Add Tasks to registry."""

import functools
import json
from re import I

from transformer import text_dataset
import seqio
import t5.data
from t5.data import preprocessors
import tensorflow as tf
from absl import logging
from copy import deepcopy

from seqio.vocabularies import Vocabulary
from transformer.gc_lm_dataformat import GCReader
from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional, Sequence, Tuple, Type, Union
import numpy as np
import multiprocessing as mp
from transformers import GPT2TokenizerFast
from bigbench.bbseqio import task_api
from bigbench.bbseqio import tasks as bbtasks
from bigbench.bbseqio import vocabs
from bigbench.bbseqio import bigbench_bridge as bb
from bigbench.api import task_metrics
from seqio import utils
import jax
import jax.numpy as jnp

import glob
import re

# Common path prefix that is prepended to the dataset and vocabulary locations
# used further down in this module.
DATAPATH = "" # removed for review

class CharVocabulary(Vocabulary):
  """Character-level vocabulary that does table-lookup of unigrams.

  Id 0 is "PAD", the supplied unigrams take ids 1..N in the order given, and
  id N + 1 is "UNK". Characters that are not in the table are encoded as UNK.
  """

  def __init__(self, unigrams):
    """CharVocabulary constructor.

    Args:
      unigrams: the collection of in-vocabulary tokens. This collection should
        not include PAD or UNK, which are automatically assigned ids and managed
        as possible decode tokens. Any iterable is accepted; it is materialized
        into a list exactly once.
    """

    super().__init__()
    unigrams_as_list = list(unigrams)
    self._unigram_by_id = ["PAD"] + unigrams_as_list + ["UNK"]
    self._id_by_unigram = {u: i for i, u in enumerate(self._unigram_by_id)}
    initializer = tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(["PAD"] + unigrams_as_list),
        # One extra value because the leading 0 corresponds to PAD. The length
        # is taken from the materialized list so that generators and other
        # unsized iterables work as well.
        values=tf.constant(range(len(unigrams_as_list) + 1), dtype=tf.int64))
    # The single out-of-vocabulary bucket is placed right after the known keys,
    # i.e. at the position of "UNK" in `_unigram_by_id`.
    self._id_by_unigram_tf = tf.lookup.StaticVocabularyTable(initializer,
                                                             num_oov_buckets=1)
    self._unigram_by_id_tf = tf.constant(self._unigram_by_id)

  def _encode(self, s: str):
    """Returns one id per character of `s`; unknown characters map to UNK."""
    unk = self.unk_id
    return [self._id_by_unigram.get(c, unk) for c in s]

  def _encode_tf(self, s):
    """Splits `s` into UTF-8 characters and looks each one up as an int32 id."""
    chars = tf.strings.unicode_split(s, "UTF-8")
    tf_ids = self._id_by_unigram_tf.lookup(chars)
    return tf.dtypes.cast(tf_ids, tf.int32)

  def _decode(self, ids) -> str:
    """Returns the unigrams for `ids` joined by single spaces."""
    # NOTE(review): the tokens are joined with " ", so decoding does not
    # reproduce the encoded string character-for-character -- confirm intended.
    return " ".join(self._unigram_by_id[token_id] for token_id in ids)

  def _decode_tf(self, ids):
    """Looks up the unigram selected by `ids[0]`."""
    # NOTE(review): only `ids[0]` is used as the index here -- confirm that
    # decoding a single position is what callers expect.
    return self._unigram_by_id_tf[ids[0]]

  @property
  def _base_vocab_size(self):
    """Number of ids, including PAD and UNK."""
    return len(self._unigram_by_id)

  @property
  def eos_id(self):
    """This vocabulary defines no end-of-sequence id."""
    return None

  @property
  def unk_id(self):
    """UNK is the last id in the table."""
    return self._base_vocab_size - 1




# Module-level alias for the seqio task registry.
TaskRegistry = seqio.TaskRegistry


def define_pg19_task(name: str, vocab: seqio.Vocabulary):
  """Registers a task named `name` that reads the TFDS dataset "pg19:0.1.1".

  The "book_text" field is rekeyed to "targets" (keeping the title, id and
  publication date fields) and then tokenized with `vocab`.

  Args:
    name: name under which the task is added to `seqio.TaskRegistry`.
    vocab: vocabulary used for the "targets" feature; no EOS is appended.
  """
  source = seqio.TfdsDataSource(tfds_name="pg19:0.1.1")
  rekey_books = functools.partial(
      text_dataset.rekey_articles,
      rekey={"book_text": "targets"},
      keep={"book_title", "book_id", "publication_date"},
  )
  targets_feature = seqio.Feature(vocab, add_eos=False, dtype=tf.int32)
  seqio.TaskRegistry.add(
      name,
      source,
      preprocessors=[rekey_books, seqio.preprocessors.tokenize],
      output_features={"targets": targets_feature},
  )



# Default T5 vocabulary, shared by the token-level tasks registered below.
T5_DEFAULT_VOCABULARY = t5.data.get_default_vocabulary()

# PG-19 at byte level and at T5-token level.
define_pg19_task("pg19_bytes", seqio.ByteVocabulary())
define_pg19_task("pg19_tokens", T5_DEFAULT_VOCABULARY)

# C4 (English): the "text" field becomes "targets" and is tokenized.
seqio.TaskRegistry.add(
    "c4_lm_v230_cl_standard",
    seqio.TfdsDataSource(tfds_name="c4/en:3.0.1"),
    preprocessors=[
        functools.partial(
            preprocessors.rekey,
            key_map={"inputs": None, "targets": "text"},
        ),
        seqio.preprocessors.tokenize,
    ],
    output_features={
        "targets": seqio.Feature(
            T5_DEFAULT_VOCABULARY, add_eos=False, dtype=tf.int32
        ),
    },
)

# Wiki40B (English): the "text" field becomes "targets" and is tokenized.
seqio.TaskRegistry.add(
    "wiki40b",
    seqio.TfdsDataSource(tfds_name="wiki40b/en:1.3.0"),
    preprocessors=[
        functools.partial(
            preprocessors.rekey,
            key_map={"inputs": None, "targets": "text"},
        ),
        seqio.preprocessors.tokenize,
        # t5.data.preprocessors.concatenate_and_split_to_fixed_length,
    ],
    output_features={
        "targets": seqio.Feature(
            T5_DEFAULT_VOCABULARY, add_eos=False, dtype=tf.int32
        ),
    },
)
@utils.map_over_dataset(num_seeds=2)
def dropout_shuffle(x, seeds, dropout_prob, vocab_size):
  """Randomly remaps part of the vocabulary ids found in `x["targets"]`.

  A lookup table of size `vocab_size` is built inside the mapped function: each
  id is replaced by an id drawn uniformly from [0, vocab_size) with probability
  `dropout_prob` and kept unchanged otherwise. The replacement ids are drawn
  independently, so the table is not necessarily a permutation.

  Args:
    x: example dict; its "targets" entry is overwritten in place.
    seeds: two stateless seeds, one for the replacement ids and one for the
      choice of which ids get replaced.
    dropout_prob: probability with which an id is replaced.
    vocab_size: number of ids in the lookup table.

  Returns:
    `x` with "targets" passed through the lookup table.
  """
  table_shape = [vocab_size]
  identity_ids = tf.range(vocab_size)
  random_ids = tf.random.stateless_uniform(
      shape=table_shape,
      seed=seeds[0],
      minval=0,
      maxval=vocab_size,
      dtype=tf.int32)
  # 1 where the id is replaced by its random counterpart, 0 where it is kept.
  replace_mask = tf.random.stateless_binomial(
      shape=table_shape,
      seed=seeds[1],
      counts=tf.ones(table_shape),
      probs=dropout_prob)
  remap_table = random_ids * replace_mask + identity_ids * (1 - replace_mask)
  x["targets"] = tf.gather(remap_table, x["targets"])
  return x

# Wiki40B (English) with `dropout_shuffle` applied to the tokenized targets.
seqio.TaskRegistry.add(
    "wiki40b_rnd",
    seqio.TfdsDataSource(tfds_name="wiki40b/en:1.3.0"),
    preprocessors=[
        functools.partial(
            preprocessors.rekey,
            key_map={"inputs": None, "targets": "text"},
        ),
        seqio.preprocessors.tokenize,
        functools.partial(
            dropout_shuffle,
            dropout_prob=0.5,
            vocab_size=T5_DEFAULT_VOCABULARY.vocab_size,
        ),
        # t5.data.preprocessors.concatenate_and_split_to_fixed_length,
    ],
    output_features={
        "targets": seqio.Feature(
            T5_DEFAULT_VOCABULARY, add_eos=False, dtype=tf.int32
        ),
    },
)

# Wiki40B (English) tokenized per character over the 128 ASCII code points.
seqio.TaskRegistry.add(
    "wiki40b_char",
    seqio.TfdsDataSource(tfds_name="wiki40b/en:1.3.0"),
    preprocessors=[
        functools.partial(
            preprocessors.rekey,
            key_map={"inputs": None, "targets": "text"},
        ),
        seqio.preprocessors.tokenize,
        # t5.data.preprocessors.concatenate_and_split_to_fixed_length,
    ],
    output_features={
        "targets": seqio.Feature(
            CharVocabulary(list(map(chr, range(128)))),
            add_eos=False,
            dtype=tf.int32,
        ),
    },
)



class GCFileDataSource(seqio.DataSource):
  """A `DataSource` whose examples are streamed from files by `GCReader`.

  Every document yielded by the reader becomes one example of the form
  `{"targets": <scalar string>}`.
  """

  def __init__(self,
               split_to_filepattern: Mapping[str, Union[str, Iterable[str]]],
               num_input_examples: Optional[Mapping[str, int]] = None,
               caching_permitted: bool = True,
               file_shuffle_buffer_size: Optional[int] = None):
    """GCFileDataSource constructor.

    Args:
      split_to_filepattern: maps each split name to the file pattern(s) that
        are globbed by `list_shards`.
      num_input_examples: forwarded to `seqio.DataSource`.
      caching_permitted: forwarded to `seqio.DataSource`.
      file_shuffle_buffer_size: buffer size for shuffling the file names; when
        unset, the number of files is used.
    """

    def iter_documents(paths):
      # The whole list of paths is handed to a single reader.
      reader = GCReader(paths)
      for document in reader.stream_data(threaded=True):
        yield {"targets": document}

    def read_file_fn(file_dataset):
      signature = {"targets": tf.TensorSpec(shape=(), dtype=tf.string)}

      def generator():
        # The file-name dataset is iterated each time the generator is
        # started; names arrive as bytes and are decoded to str.
        paths = [p.decode("utf-8") for p in file_dataset.as_numpy_iterator()]
        return iter_documents(paths)

      return tf.data.Dataset.from_generator(
          generator, output_signature=signature)

    self._split_to_filepattern = split_to_filepattern
    self._reader = read_file_fn
    self._file_shuffle_buffer_size = file_shuffle_buffer_size
    super().__init__(
        splits=split_to_filepattern.keys(),
        num_input_examples=num_input_examples,
        caching_permitted=caching_permitted)

  @property
  def supports_arbitrary_sharding(self) -> bool:
    """Sharding happens at file granularity only."""
    return False

  def get_dataset(self,
                  split: str,
                  shuffle: bool = True,
                  seed: Optional[int] = None,
                  shard_info: Optional[seqio.ShardInfo] = None) -> tf.data.Dataset:
    """Builds the example dataset for `split`.

    Args:
      split: name of the split whose files are read.
      shuffle: whether to shuffle the order of the file names.
      seed: seed for the file-name shuffle.
      shard_info: if given, only the file names of that shard are kept.

    Returns:
      A dataset of `{"targets": tf.string}` examples.

    Raises:
      ValueError: if no file matches the pattern, or if there are fewer files
        than requested shards.
    """
    files = self.list_shards(split)
    if not files:
      raise ValueError("No file is found for the file pattern: "
                       f"{self._split_to_filepattern[split]}.")

    files_ds = tf.data.Dataset.from_tensor_slices(np.array(files, dtype=str))

    if shard_info:
      num_shards = shard_info.num_shards
      if len(files) < num_shards:
        raise ValueError(
            f"Dataset has too few files to shard. {len(files)} files vs "
            f"{num_shards} shards requested.")
      files_ds = files_ds.shard(num_shards, shard_info.index)

    if shuffle:
      buffer_size = self._file_shuffle_buffer_size
      if buffer_size:
        logging.warning(
            "`file_shuffle_buffer_size` is explicitly set to %d; this may lead "
            "to an imperfect file shuffle. Leave `file_shuffle_buffer_size` "
            "unset for a perfect shuffle.", buffer_size)
      else:
        # A buffer as large as the file list covers every file name.
        buffer_size = len(files)
      files_ds = files_ds.shuffle(buffer_size=buffer_size, seed=seed)

    return self._reader(files_ds)

  def list_shards(self, split: str) :
    """Returns the files matching the pattern registered for `split`."""
    pattern = self._split_to_filepattern[split]
    return tf.io.gfile.glob(pattern)



# Glob patterns for the GitHub data shards. Files whose index starts with 0
# form the test split; files whose index starts with any digit 1-9 form the
# train split.
# A glob character range is spelled `[1-9]`. The former `[1..9]` was a bracket
# expression containing only the characters '1', '.' and '9', which silently
# excluded the shards starting with 2 through 8.
split_to_filepattern = {
    "train": DATAPATH+"github_data_25B/github_data/data_[1-9]*",
    "test": DATAPATH+"github_data_25B/github_data/data_0_*"
}

    
# GitHub code data, tokenized with the default T5 vocabulary.
seqio.TaskRegistry.add(
    "github_data_25B",
    GCFileDataSource(split_to_filepattern=split_to_filepattern),
    preprocessors=[
        seqio.preprocessors.tokenize,
        # t5.data.preprocessors.concatenate_and_split_to_fixed_length,
    ],
    output_features={
        "targets": seqio.Feature(
            T5_DEFAULT_VOCABULARY, add_eos=False, dtype=tf.int32
        ),
    },
)

# GitHub code data with `dropout_shuffle` applied to the tokenized targets.
seqio.TaskRegistry.add(
    "github_data_25B_rnd",
    GCFileDataSource(split_to_filepattern=split_to_filepattern),
    preprocessors=[
        seqio.preprocessors.tokenize,
        functools.partial(
            dropout_shuffle,
            dropout_prob=0.5,
            vocab_size=T5_DEFAULT_VOCABULARY.vocab_size,
        ),
        # t5.data.preprocessors.concatenate_and_split_to_fixed_length,
    ],
    output_features={
        "targets": seqio.Feature(
            T5_DEFAULT_VOCABULARY, add_eos=False, dtype=tf.int32
        ),
    },
)

# GitHub code data tokenized per character over the 128 ASCII code points.
seqio.TaskRegistry.add(
    "github_data_25B_char",
    GCFileDataSource(split_to_filepattern=split_to_filepattern),
    preprocessors=[
        seqio.preprocessors.tokenize,
        # t5.data.preprocessors.concatenate_and_split_to_fixed_length,
    ],
    output_features={
        "targets": seqio.Feature(
            CharVocabulary(list(map(chr, range(128)))),
            add_eos=False,
            dtype=tf.int32,
        ),
    },
)

# Mixtures with wiki and code at rates 3 and 1.
seqio.MixtureRegistry.add(
    "wiki_code_mix",
    [("wiki40b", 3), ("github_data_25B", 1)],
)

seqio.MixtureRegistry.add(
    "wiki_code_mix_rnd",
    [("wiki40b_rnd", 3), ("github_data_25B_rnd", 1)],
)


class GPT2Vocabulary(seqio.Vocabulary):
  """seqio vocabulary backed by the HuggingFace GPT-2 fast tokenizer.

  Ids produced by the tokenizer are shifted up by one on encode, so the
  tokenizer itself never emits id 0; decode undoes the shift.
  """

  def __init__(self):
    """Loads the pretrained "gpt2" tokenizer and adds UNK and PAD tokens."""

    super().__init__()
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    tokenizer.add_special_tokens({'unk_token': '??'})
    tokenizer.add_special_tokens({'pad_token': ''})
    self.gpt2_tokenizer = tokenizer

  def _unshift(self, shifted_ids):
    """Maps seqio-side ids (a numpy array) back to tokenizer-side ids."""
    tokenizer_ids = shifted_ids - 1
    # Id 0 on the seqio side becomes -1 here; it is redirected to the last id.
    tokenizer_ids[tokenizer_ids == -1] = self._base_vocab_size - 1
    return tokenizer_ids

  def _encode(self, s: str) -> Sequence[int]:
    """Tokenizes `s` and returns the shifted ids as a Python list."""
    token_ids = self.gpt2_tokenizer.encode(s, return_tensors="np")
    shifted = token_ids + 1
    # The tokenizer returns a batch of one row; only that row is kept.
    return shifted.tolist()[0]

  def _encode_tf(self, s: tf.Tensor) -> tf.Tensor:
    """Tokenizes the scalar string tensor `s` through `tf.py_function`."""

    def eager_encode_tf(t: tf.Tensor):
      text = t.numpy().decode("utf-8")
      token_ids = self.gpt2_tokenizer.encode(text, return_tensors="tf")
      return (token_ids + 1)[0]

    ragged_spec = tf.RaggedTensorSpec(shape=[None], dtype=tf.int32)
    return tf.py_function(func=eager_encode_tf, inp=[s], Tout=ragged_spec)

  def _decode(self, ids: Sequence[int]) -> str:
    """Decodes a sequence of shifted ids into text."""
    return self.gpt2_tokenizer.decode(self._unshift(np.array(ids)))

  def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor:
    """Decodes a tensor of shifted ids into text."""
    # NOTE(review): relies on `ids.numpy()` and returns whatever the tokenizer
    # returns -- presumably only usable in eager mode; confirm.
    return self.gpt2_tokenizer.decode(self._unshift(ids.numpy()))


  @property
  def _base_vocab_size(self):
    """Tokenizer vocabulary size plus two."""
    return self.gpt2_tokenizer.vocab_size + 2

  @property
  def eos_id(self):
    """Second-to-last id."""
    return self._base_vocab_size - 2

  @property
  def unk_id(self):
    """Last id."""
    return self._base_vocab_size - 1

# Single GPT-2 vocabulary instance shared by the tasks below.
GPT2_VOCABULARY = GPT2Vocabulary()

# Wiki40B (English) tokenized with the GPT-2 vocabulary.
seqio.TaskRegistry.add(
    "wiki40b_gpt2_vocab",
    seqio.TfdsDataSource(tfds_name="wiki40b/en:1.3.0"),
    preprocessors=[
        functools.partial(
            preprocessors.rekey,
            key_map={"inputs": None, "targets": "text"},
        ),
        seqio.preprocessors.tokenize,
        # t5.data.preprocessors.concatenate_and_split_to_fixed_length,
    ],
    output_features={
        "targets": seqio.Feature(
            GPT2_VOCABULARY, add_eos=False, dtype=tf.int32
        ),
    },
)

# GitHub code data tokenized with the GPT-2 vocabulary.
seqio.TaskRegistry.add(
    "github_data_25B_gpt2_vocab",
    GCFileDataSource(split_to_filepattern=split_to_filepattern),
    preprocessors=[
        seqio.preprocessors.tokenize,
        # t5.data.preprocessors.concatenate_and_split_to_fixed_length,
    ],
    output_features={
        "targets": seqio.Feature(
            GPT2_VOCABULARY, add_eos=False, dtype=tf.int32
        ),
    },
)

seqio.MixtureRegistry.add(
    "wiki_code_mix_gpt2_vocab",
    [("wiki40b_gpt2_vocab", 3), ("github_data_25B_gpt2_vocab", 1)],
)

# Vocabulary used by the "*_64k_vocab" tasks.
# NOTE(review): it is None here, so the features below are built with a None
# vocabulary -- presumably the real vocabulary was removed from this copy.
LARGER_VOCAB = None

# Wiki40B (English) with the larger vocabulary.
seqio.TaskRegistry.add(
    "wiki40b_64k_vocab",
    seqio.TfdsDataSource(tfds_name="wiki40b/en:1.3.0"),
    preprocessors=[
        functools.partial(
            preprocessors.rekey,
            key_map={"inputs": None, "targets": "text"},
        ),
        seqio.preprocessors.tokenize,
        # t5.data.preprocessors.concatenate_and_split_to_fixed_length,
    ],
    output_features={
        "targets": seqio.Feature(LARGER_VOCAB, add_eos=False, dtype=tf.int32),
    },
)

# GitHub code data with the larger vocabulary.
seqio.TaskRegistry.add(
    "github_data_25B_64k_vocab",
    GCFileDataSource(split_to_filepattern=split_to_filepattern),
    preprocessors=[
        seqio.preprocessors.tokenize,
        # t5.data.preprocessors.concatenate_and_split_to_fixed_length,
    ],
    output_features={
        "targets": seqio.Feature(LARGER_VOCAB, add_eos=False, dtype=tf.int32),
    },
)

seqio.MixtureRegistry.add(
    "wiki_code_mix_64k_vocab",
    [("wiki40b_64k_vocab", 3), ("github_data_25B_64k_vocab", 1)],
)


def feature_to_spec(feature, length=False):
    """Converts a HuggingFace `datasets` feature description to TensorSpec(s).

    Args:
        feature: a `datasets` feature (ClassLabel, Value, Sequence), an object
            exposing `dtype` and `shape`, or a list/dict of such features.
        length: falsy for a scalar, -1 for a vector of unknown length,
            otherwise the fixed vector length.

    Returns:
        A `tf.TensorSpec`, or a list/dict of specs mirroring `feature`.

    Raises:
        ValueError: if the feature type is not recognized.
    """

    def leaf_shape():
        # Scalar when no length is given, otherwise a rank-1 shape.
        if not length:
            return ()
        return (None,) if length == -1 else (length,)

    if isinstance(feature, datasets.ClassLabel):
        return tf.TensorSpec(shape=leaf_shape(), dtype=tf.int64)
    if isinstance(feature, datasets.Value):
        return tf.TensorSpec(shape=leaf_shape(), dtype=getattr(tf.dtypes, feature.dtype))
    if hasattr(feature, "dtype") and hasattr(feature, "shape"):
        return tf.TensorSpec(shape=feature.shape, dtype=feature.dtype)
    if isinstance(feature, datasets.Sequence):
        # The sequence's own length replaces the one passed in.
        return feature_to_spec(feature.feature, length=feature.length)
    if isinstance(feature, list):
        return [feature_to_spec(item, length=length) for item in feature]
    if isinstance(feature, dict):
        return {key: feature_to_spec(value, length=length) for key, value in feature.items()}
    raise ValueError(f"Unparseable feature type {type(feature)}")


def hf_dataset_to_tf_dataset(dataset):
    """Wraps `dataset` in a `tf.data.Dataset` that iterates over it.

    The output signature is derived from `dataset.features` with
    `feature_to_spec`.
    """
    signature = {
        name: feature_to_spec(feature)
        for name, feature in dataset.features.items()
    }
    return tf.data.Dataset.from_generator(
        dataset.__iter__, output_signature=signature)

import datasets
def get_tf_dataset(split, shuffle_files):
    """Streams the "all" config of the local Pile loader script for `split`.

    `shuffle_files` is ignored (HF datasets does not support file-level
    shuffling); instead the stream is shuffled with a fixed seed and a
    10,000-example buffer.
    """
    del shuffle_files
    streamed = datasets.load_dataset(
        "transformer/the_pile.py", "all", split=split, streaming=True)
    shuffled = streamed.shuffle(seed=42, buffer_size=10_000)
    return hf_dataset_to_tf_dataset(shuffled)

  
def get_TBC_tf_dataset(split, task = "NE"):
    """Streams config `task` of the HF "cbt" dataset for `split`.

    The stream is shuffled with a fixed seed and a 10,000-example buffer
    before it is wrapped in a `tf.data.Dataset`.
    """
    streamed = datasets.load_dataset("cbt", task, split=split, streaming=True)
    shuffled = streamed.shuffle(seed=42, buffer_size=10_000)
    return hf_dataset_to_tf_dataset(shuffled)
    

# The Pile streamed through HuggingFace datasets (see `get_tf_dataset`).
seqio.TaskRegistry.add(
    "the_pile",
    seqio.FunctionDataSource(get_tf_dataset, splits=["train", "test"]),
    preprocessors=[
        functools.partial(
            preprocessors.rekey,
            key_map={"inputs": None, "targets": "text"},
        ),
        seqio.preprocessors.tokenize,
        # t5.data.preprocessors.concatenate_and_split_to_fixed_length,
    ],
    output_features={
        "targets": seqio.Feature(LARGER_VOCAB, add_eos=False, dtype=tf.int32),
    },
)

import jsonlines

class GCJsonDataSource(seqio.DataSource):
  """A `DataSource` whose examples are read from JSON-lines files.

  The "text" field of every record becomes one example of the form
  `{"targets": <scalar string>}`.
  """
  def __init__(self,
                split_to_filepattern: Mapping[str, Union[str, Iterable[str]]],
                num_input_examples: Optional[Mapping[str, int]] = None,
                caching_permitted: bool = True,
                file_shuffle_buffer_size: Optional[int] = None):
    """GCJsonDataSource constructor.

    Args:
      split_to_filepattern: maps each split name to the file pattern(s) that
        are globbed by `list_shards`.
      num_input_examples: forwarded to `seqio.DataSource`.
      caching_permitted: forwarded to `seqio.DataSource`.
      file_shuffle_buffer_size: buffer size for shuffling the file names; when
        unset, the number of files is used.
    """

    def iter_json_records(paths):
      # Files are read one after another, in the order given.
      for path in paths:
        with tf.io.gfile.GFile(path, 'rb+') as handle:
          for record in jsonlines.Reader(handle):
            # if item['repetitions'] > 0:
            # result['meta'] = item['meta']
            yield {'targets': record['text']}

    def read_file_fn(file_dataset):
      # return tf.data.Dataset.from_tensor_slices([file_dataset])
      signature = {"targets": tf.TensorSpec(shape=(), dtype=tf.string)}

      def generator():
        # The file-name dataset is iterated each time the generator is
        # started; names arrive as bytes and are decoded to str.
        paths = [p.decode("utf-8") for p in file_dataset.as_numpy_iterator()]
        return iter_json_records(paths)

      return tf.data.Dataset.from_generator(
          generator, output_signature=signature)

    self._split_to_filepattern = split_to_filepattern
    self._reader = read_file_fn
    self._file_shuffle_buffer_size = file_shuffle_buffer_size
    super().__init__(
        splits=split_to_filepattern.keys(),
        num_input_examples=num_input_examples,
        caching_permitted=caching_permitted)

  @property
  def supports_arbitrary_sharding(self) -> bool:
    """Sharding happens at file granularity only."""
    return False

  def get_dataset(self,
                  split: str,
                  shuffle: bool = True,
                  seed: Optional[int] = None,
                  shard_info: Optional[seqio.ShardInfo] = None) -> tf.data.Dataset:
    """Builds the example dataset for `split`.

    Args:
      split: name of the split whose files are read.
      shuffle: whether to shuffle the order of the file names.
      seed: seed for the file-name shuffle.
      shard_info: if given, only the file names of that shard are kept.

    Returns:
      A dataset of `{"targets": tf.string}` examples.

    Raises:
      ValueError: if no file matches the pattern, or if there are fewer files
        than requested shards.
    """
    files = self.list_shards(split)
    if not files:
      raise ValueError("No file is found for the file pattern: "
                       f"{self._split_to_filepattern[split]}.")

    files_ds = tf.data.Dataset.from_tensor_slices(np.array(files, dtype=str))

    if shard_info:
      num_shards = shard_info.num_shards
      if len(files) < num_shards:
        raise ValueError(
            f"Dataset has too few files to shard. {len(files)} files vs "
            f"{num_shards} shards requested.")
      files_ds = files_ds.shard(num_shards, shard_info.index)

    if shuffle:
      buffer_size = self._file_shuffle_buffer_size
      if buffer_size:
        logging.warning(
            "`file_shuffle_buffer_size` is explicitly set to %d; this may lead "
            "to an imperfect file shuffle. Leave `file_shuffle_buffer_size` "
            "unset for a perfect shuffle.", buffer_size)
      else:
        # A buffer as large as the file list covers every file name.
        buffer_size = len(files)
      files_ds = files_ds.shuffle(buffer_size=buffer_size, seed=seed)

    return self._reader(files_ds)

  def list_shards(self, split: str) :
    """Returns the files matching the pattern registered for `split`."""
    pattern = self._split_to_filepattern[split]
    return tf.io.gfile.glob(pattern)


import tensorflow_datasets as tfds
from transformer import tfds_pile

def gcs_dataset_fn(split, shuffle_files):
  """Loads "ThePile:1.0.0" through TFDS and returns the requested split.

  The split is always shuffled with a fixed seed and a 10,000-example buffer;
  `shuffle_files` is accepted but not used.
  """
  all_splits = tfds.load(
      name="ThePile:1.0.0", try_gcs=True, data_dir=DATAPATH+"ThePile")
  selected = all_splits[split]
  return selected.shuffle(seed=42, buffer_size=10_000)

# The Pile loaded through TFDS (see `gcs_dataset_fn`), registered once per
# vocabulary. Every variant rekeys "text" to "targets" and tokenizes it.

# Larger vocabulary.
seqio.TaskRegistry.add(
    "the_pile_gcs",
    seqio.FunctionDataSource(gcs_dataset_fn, splits=["train", "test"]),
    preprocessors=[
        functools.partial(
            preprocessors.rekey,
            key_map={"inputs": None, "targets": "text"},
        ),
        seqio.preprocessors.tokenize,
        # t5.data.preprocessors.concatenate_and_split_to_fixed_length,
    ],
    output_features={
        "targets": seqio.Feature(LARGER_VOCAB, add_eos=False, dtype=tf.int32),
    },
)

# Character vocabulary over the 128 ASCII code points.
seqio.TaskRegistry.add(
    "the_pile_gcs_char",
    seqio.FunctionDataSource(gcs_dataset_fn, splits=["train", "test"]),
    preprocessors=[
        functools.partial(
            preprocessors.rekey,
            key_map={"inputs": None, "targets": "text"},
        ),
        seqio.preprocessors.tokenize,
        # t5.data.preprocessors.concatenate_and_split_to_fixed_length,
    ],
    output_features={
        "targets": seqio.Feature(
            CharVocabulary(list(map(chr, range(128)))),
            add_eos=False,
            dtype=tf.int32,
        ),
    },
)

# SentencePiece model "vocabs/1000.model".
seqio.TaskRegistry.add(
    "the_pile_gcs_1k",
    seqio.FunctionDataSource(gcs_dataset_fn, splits=["train", "test"]),
    preprocessors=[
        functools.partial(
            preprocessors.rekey,
            key_map={"inputs": None, "targets": "text"},
        ),
        seqio.preprocessors.tokenize,
        # t5.data.preprocessors.concatenate_and_split_to_fixed_length,
    ],
    output_features={
        "targets": seqio.Feature(
            seqio.SentencePieceVocabulary(DATAPATH+"vocabs/1000.model"),
            add_eos=False,
            dtype=tf.int32,
        ),
    },
)

# SentencePiece model "vocabs/5000.model".
seqio.TaskRegistry.add(
    "the_pile_gcs_5k",
    seqio.FunctionDataSource(gcs_dataset_fn, splits=["train", "test"]),
    preprocessors=[
        functools.partial(
            preprocessors.rekey,
            key_map={"inputs": None, "targets": "text"},
        ),
        seqio.preprocessors.tokenize,
        # t5.data.preprocessors.concatenate_and_split_to_fixed_length,
    ],
    output_features={
        "targets": seqio.Feature(
            seqio.SentencePieceVocabulary(DATAPATH+"vocabs/5000.model"),
            add_eos=False,
            dtype=tf.int32,
        ),
    },
)

# Default T5 vocabulary.
seqio.TaskRegistry.add(
    "the_pile_gcs_32k",
    seqio.FunctionDataSource(gcs_dataset_fn, splits=["train", "test"]),
    preprocessors=[
        functools.partial(
            preprocessors.rekey,
            key_map={"inputs": None, "targets": "text"},
        ),
        seqio.preprocessors.tokenize,
        # t5.data.preprocessors.concatenate_and_split_to_fixed_length,
    ],
    output_features={
        "targets": seqio.Feature(
            T5_DEFAULT_VOCABULARY, add_eos=False, dtype=tf.int32
        ),
    },
)
  
# seqio.TaskRegistry.add(
#   "the_filtered_pile",
#   GCJsonDataSource(split_to_filepattern = {"train": DATAPATH+"filtered_pile/*"}),
#   preprocessors=[
#       seqio.preprocessors.tokenize,
#       # t5.data.preprocessors.concatenate_and_split_to_fixed_length,
#   ],
#   output_features={"targets": seqio.Feature(LARGER_VOCAB,
#                                   add_eos=False, dtype=tf.int32),})

def filter_math(ds):
  """Keeps only the examples whose 'meta' field equals 'ArXiv'."""
  def from_arxiv(ex):
    return ex['meta'] == 'ArXiv'

  return ds.filter(from_arxiv)

# The Pile restricted by `filter_math` before rekeying and tokenization.
seqio.TaskRegistry.add(
    "the_filtered_pile",
    seqio.FunctionDataSource(gcs_dataset_fn, splits=["train", "test"]),
    preprocessors=[
        filter_math,
        functools.partial(
            preprocessors.rekey,
            key_map={"inputs": None, "targets": "text"},
        ),
        seqio.preprocessors.tokenize,
        # t5.data.preprocessors.concatenate_and_split_to_fixed_length,
    ],
    output_features={
        "targets": seqio.Feature(LARGER_VOCAB, add_eos=False, dtype=tf.int32),
    },
)


########## Add Bigbench data ############

# Maps a task name passed to `modify_bigbench_tasks` to the list of original
# subtask names appended for it (the list is only filled in eval mode).
ALL_INDEXED_TASKS = {}
# Maps a task name to the metric given to `modify_bigbench_tasks`; it is looked
# up by `evaluate_special_metric`.
ALL_METRICS = {}

# Matches "#### " followed by an optionally negative number that may contain
# digits, dots and commas; the number is captured in group 1.
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
# Returned by `extract_answer` when no answer marker is found.
INVALID_ANS = "[invalid]"


def extract_answer(completion):
    """Returns the number following "#### " in `completion`, commas removed.

    If the marker is not found, `INVALID_ANS` is returned instead.
    """
    found = ANS_RE.search(completion)
    if found is None:
        return INVALID_ANS
    return found.group(1).strip().replace(",", "")


def is_correct(model_completion, gt_answer):
    """Tells whether the model's extracted answer equals the reference one.

    The reference text `gt_answer` must contain an answer marker; this is
    checked with an assertion.
    """
    reference = extract_answer(gt_answer)
    assert reference != INVALID_ANS
    predicted = extract_answer(model_completion)
    return predicted == reference
  
def get_masked_tokens(token, loss_mask):
  """Returns the tokens from the first nonzero-cumulative mask position on.

  The selected tokens are moved to the front of an array that has the same
  length as `token`; the remaining positions are filled with 0.
  """
  num_tokens = token.shape[0]
  # True from the first position where the running sum of the mask is positive.
  active = jnp.cumsum(loss_mask) > 0
  # Fixed-size index list; unused slots point one past the last token.
  (keep_idx,) = jnp.where(active, size=num_tokens, fill_value=num_tokens)
  # The extra trailing 0 is what the unused slots pick up.
  padded = jnp.pad(token, [(0, 1)])
  return padded[keep_idx]

def evaluate_special_metric(gen_tokens, target_tokens, loss_masks, task_name):
  """Scores generated tokens against targets with the metric of `task_name`.

  Both token arrays are reduced row by row with `get_masked_tokens`, decoded
  with the vocabulary of the task's "targets" feature, and compared using the
  metric registered in `ALL_METRICS`.

  Args:
    gen_tokens: generated token ids, one row per example.
    target_tokens: reference token ids, one row per example.
    loss_masks: masks handed to `get_masked_tokens`, one row per example.
    task_name: name of a registered task or mixture with an `ALL_METRICS` entry.

  Returns:
    The "bleu" score, the "rougeLsum" score, or the fraction of examples for
    which `is_correct` holds ("GSM8K_score").

  Raises:
    NotImplementedError: if the registered metric is none of the above.
  """
  task = seqio.get_mixture_or_task(task_name)
  vocab = task.output_features["targets"].vocabulary
  metric = ALL_METRICS[task_name]
  num_examples = gen_tokens.shape[0]

  select_rows = jax.vmap(get_masked_tokens)
  prediction_ids = select_rows(gen_tokens, loss_masks)
  target_ids = select_rows(target_tokens, loss_masks)

  predictions = [vocab.decode(int(tok) for tok in row) for row in prediction_ids]
  # Each target is wrapped in its own single-element list.
  targets = [[vocab.decode(int(tok) for tok in row)] for row in target_ids]

  if metric == "bleu":
    return task_metrics.bleu_fn(targets, predictions)["bleu"]
  if metric == "rougeLsum":
    return task_metrics.rouge_fn(targets, predictions)["rougeLsum"]
  if metric == "GSM8K_score":
    num_correct = sum(
        float(is_correct(prediction, references[0]))
        for prediction, references in zip(predictions, targets))
    return num_correct / num_examples
  raise NotImplementedError


@utils.map_over_dataset
def add_const_entries(x, entries):
  """Writes every key/value pair of `entries` into the example `x`."""
  x.update(entries)
  return x

def filter_correct_choice(ds):
  """Keeps examples whose 'scores' equals 1; examples without 'scores' pass."""
  return ds.filter(
      lambda ex: True if 'scores' not in ex else ex['scores'] == 1)

def filter_too_long(sequence_length, ds):
  """Drops examples whose inputs plus targets do not fit `sequence_length`.

  An example is kept when len(inputs) + len(targets) < sequence_length - 1.
  """
  length_limit = sequence_length - 1
  return ds.filter(
      lambda ex: len(ex['inputs']) + len(ex['targets']) < length_limit)

@utils.map_over_dataset
def encode_decode(x, vocab):
  """Replaces `x["answers"]` by its encode/decode round trip through `vocab`."""
  encoded = vocab.encode_tf(x["answers"])
  x["answers"] = vocab.decode_tf(encoded.to_tensor())
  return x


def modify_bigbench_tasks(subtasks, task_name, sequence_length, metric = None, eval = True):
  """Registers an "eval_"/"train_" variant of every task in `subtasks`.

  Each variant reuses the source, postprocessor and metric functions of the
  original task and appends extra preprocessors to the original ones.

  Args:
    subtasks: names of already registered seqio tasks.
    task_name: key under which the subtasks are recorded in
      `ALL_INDEXED_TASKS` (and `metric` in `ALL_METRICS`); must be new.
    sequence_length: length used for filtering and for the decoder feature
      converter.
    metric: optional metric name stored in `ALL_METRICS[task_name]`.
    eval: if true, build evaluation variants; otherwise training variants.

  Returns:
    The names of the newly registered tasks, in the order of `subtasks`.
  """
  assert task_name not in ALL_INDEXED_TASKS
  indexed_subtasks = ALL_INDEXED_TASKS[task_name] = []
  if metric:
    ALL_METRICS[task_name] = metric

  def passthrough_feature(dtype, **feature_kwargs):
    return seqio.Feature(
        seqio.PassThroughVocabulary(256, 0),
        add_eos=False,
        dtype=dtype,
        **feature_kwargs)

  registered_names = []
  for subtask in subtasks:
    task = seqio.TaskRegistry.get(subtask)
    base_preprocessors = list(task.preprocessors)

    # Both modes rename the converter outputs back to task-level keys.
    to_task_keys = functools.partial(
        preprocessors.rekey,
        key_map={
            "loss_mask": "decoder_loss_weights",  # currently not used in train mode, but could
            "targets": "decoder_input_tokens",
        })

    if eval:
      extra_preprocessors = [
          # functools.partial(
          #   encode_decode,
          #   vocab=task.output_features["targets"].vocabulary
          # ),
          functools.partial(filter_too_long, sequence_length),
          filter_correct_choice,
          functools.partial(
              seqio.DecoderFeatureConverter(pack=False),
              task_feature_lengths={
                  "inputs": sequence_length,
                  "targets": sequence_length,
              }),
          to_task_keys,
          functools.partial(
              add_const_entries,
              entries={
                  # Position this subtask will take in the indexed list; it is
                  # read before the subtask is appended below.
                  "task_idx": len(indexed_subtasks),
                  "start_of_sequence": True,
                  "epoch": 0,
              }),
      ]
      indexed_subtasks.append(subtask)
      # task._preprocessors = tuple(original_ps + p)
      output_features = deepcopy(task.output_features)
      output_features["loss_mask"] = passthrough_feature(tf.int32)
      output_features["task_idx"] = passthrough_feature(tf.int32, rank=0)
      output_features["start_of_sequence"] = passthrough_feature(tf.bool, rank=0)
      output_features["epoch"] = passthrough_feature(tf.int32, rank=0)
      new_name = "eval_" + subtask
    else:
      extra_preprocessors = [
          filter_correct_choice,
          functools.partial(
              seqio.DecoderFeatureConverter(pack=False),
              task_feature_lengths={
                  "targets": sequence_length,
              }),
          to_task_keys,
          functools.partial(
              dropout_shuffle,
              dropout_prob=0.5,
              vocab_size=task.output_features["targets"].vocabulary.vocab_size),
      ]
      output_features = {"targets": task.output_features["targets"]}
      new_name = "train_" + subtask

    seqio.TaskRegistry.add(
        new_name,
        task.source,
        preprocessors=tuple(base_preprocessors + extra_preprocessors),
        output_features=output_features,
        postprocess_fn=task.postprocessor,
        metric_fns=task.metric_fns,
    )
    registered_names.append(new_name)
  return registered_names

def register_all_bigbench_tasks(num_shots, custom_vocab):
  """Registers all BIG-bench JSON tasks and returns their names, sorted.

  Args:
    num_shots: forwarded to `bbtasks.register_bigbench_json`.
    custom_vocab: forwarded to `bbtasks.register_bigbench_json`.

  Returns:
    Sorted names of the tasks in the mixture created by the registration.
  """
  mixture_name = bbtasks.register_bigbench_json(num_shots, custom_vocab)
  mixture = seqio.get_mixture_or_task(mixture_name)
  return sorted(task.name for task in mixture.tasks)

def register_list_function(all_subtasks, task_name, sequence_length):
  """Registers a mixture of the list_functions generation subtasks.

  A subtask is selected when its name starts with
  "bigbench:list_functions.gen" and its last dot-separated component starts
  with "c". The selected subtasks go through `modify_bigbench_tasks` and are
  mixed under `task_name` with rate 1 each.
  """

  def is_selected(name):
    has_prefix = name.startswith("bigbench:list_functions.gen")
    last_part_matches = name.split(".")[-1].startswith("c")
    return has_prefix and last_part_matches

  selected = [name for name in all_subtasks if is_selected(name)]
  new_names = modify_bigbench_tasks(selected, task_name, sequence_length)
  seqio.MixtureRegistry.add(
      task_name,
      [(name, 1) for name in new_names],
  )

def has_cache_dir(task_name, split="all"):
  """Tells whether the task or mixture `task_name` has a cache directory.

  `split` is accepted but not used.
  """
  return seqio.get_mixture_or_task(task_name).cache_dir is not None

def is_nonzero_length_cached_dataset(task_name, cached_split = "all"):
  """Tells whether the cached info JSON of `cached_split` is non-empty.

  If in cached mode, the cache is assumed to exist; the info JSON is checked
  instead of retrieving an iterator over the dataset.
  """
  cache_dir = seqio.get_mixture_or_task(task_name).cache_dir
  info_path = utils.get_cached_info_path(cache_dir, cached_split)
  with tf.io.gfile.GFile(info_path) as info_file:
    # An empty JSON object counts as a zero-length split.
    return json.load(info_file) != {}

def is_nonzero_length_dataset(task_name):
  """Tells whether the "all" split of `task_name` yields at least one example.

  A warning is logged when the dataset turns out to be empty.
  """
  task = seqio.get_mixture_or_task(task_name)
  ds = task.get_dataset(split="all", sequence_length=None)
  # Stops at the first example, so the dataset is not read in full.
  if any(True for _ in ds):
    return True
  logging.warning(f"{task_name} has zero examples after filter to current sequence length. So it is excluded.")
  return False


def register_subset_tasks(all_subtasks, selected_tasks, task_name, sequence_length, metric = None, modify = True, eval = True):
  """Register a uniform mixture of the subtasks matching the given prefixes.

  Args:
    all_subtasks: iterable of registered task names to select from.
    selected_tasks: name prefixes; a subtask is kept when it starts with any
      of them.
    task_name: name under which the mixture is registered; also forwarded to
      `modify_bigbench_tasks` when `modify` is set.
    sequence_length: forwarded to `modify_bigbench_tasks`.
    metric: forwarded to `modify_bigbench_tasks`.
    modify: when true, the matching names are replaced by whatever
      `modify_bigbench_tasks` returns for them.
    eval: forwarded to `modify_bigbench_tasks`.
  """
  matching = [
      name for name in all_subtasks
      if any(name.startswith(prefix) for prefix in selected_tasks)
  ]
  if modify:
    matching = modify_bigbench_tasks(
        matching, task_name, sequence_length, metric, eval)
  # Disabled: dropping empty subtasks before building the mixture.
  # matching = list(filter(is_nonzero_length_dataset, matching))
  # if len(matching) == 0:
  #   logging.warning(f"{task_name} has no datasets??? Failed to register")
  #   return
  uniform_rates = [1] * len(matching)
  seqio.MixtureRegistry.add(task_name, zip(matching, uniform_rates))
  print("registered: ", task_name)
  print("examples: ", matching[:3])
  

# Task-name prefixes grouped under the "emergence" label.
emergence_tasks = [
    # Multiple-choice tasks are not currently supported during training,
    # because it is difficult to regroup the choices after shuffling. The
    # ones below could, however, be recast as generation tasks:
    #   "bigbench:common_morpheme.mul",
    #   "bigbench:hindu_knowledge.mul",
    #   "bigbench:logical_args.mul",
    #   "bigbench:phrase_relatedness.mul",
    #   "bigbench:physical_intuition.mul",
    #   "bigbench:social_iqa.mul",
    #   "bigbench:strange_stories.mul",
    #   "bigbench:swahili_english_proverbs.mul",
    #
    # Generation tasks.
    "bigbench:codenames.gen",
    "bigbench:gender_inclusive_sentences_german.gen",
    "bigbench:international_phonetic_alphabet_transliterate.gen",
    "bigbench:modified_arithmetic.gen",
    "bigbench:repeat_copy_logic.gen",
    "bigbench:word_sorting.gen",
    "bigbench:word_unscrambling.gen",
]

# BIG-bench task names grouped under the arithmetic label.
arithmatics_tasks = [
    "arithmetic",
    "chinese_remainder_theorem",
    "dyck_languages",
    "dynamic_counting",
    "elementary_math_qa",
    "evaluating_information_essentiality",
    "intersect_geometry",
    "matrixshapes",
    "modified_arithmetic",
    "multistep_arithmetic",
    "navigate",
    "repeat_copy_logic",
    "simple_arithmetic",
    "simple_arithmetic_json",
    "simple_arithmetic_json_multiple_choice",
    "simple_arithmetic_json_subtasks",
    "simple_arithmetic_multiple_targets_json",
    "simple_text_editing",
    "unit_conversion",
    "unit_interpretation",
    # "unnatural_in_context_learning",  (deliberately left out)
    "vitaminc_fact_verification",
    "roots_optimization_and_games",
]

# BIG-bench task names grouped under the proofs label.
proofs_tasks = [
    "identify_math_theorems",
    "key_value_maps",
    "mathematical_induction",
]

# BIG-bench task names grouped under the decomposition label.
decomposition_tasks = [
    "analytic_entailment",
    "evaluating_information_essentiality",
    "forecasting_subquestions",
    "tracking_shuffled_objects",
]

# BIG-bench task names grouped under the algorithms label.
algorithms = [
    "cs_algorithms",
    "list_functions",
    "long_context_integration",
    "sudoku",
    "word_sorting",
]

def register_custom_tasks(task_names_and_shots, vocab, data_path = "transformer/datasets"):
  """Register locally stored BIG-bench-format tasks as generative seqio tasks.

  For every `(task_name, num_shots)` pair, the sub-directories of
  `<data_path>/<task_name>/` are treated as subtasks and one seqio task is
  registered per sub-directory. When there are none, the task is registered
  once with no subtask name. Every registration points at
  `<data_path>/<task_name>/task.json`.

  Args:
    task_names_and_shots: iterable of `(task_name, num_shots)` pairs.
    vocab: vocabulary forwarded to `task_api.register_seqio_task`.
    data_path: directory that contains one folder per task.

  Returns:
    The list of values returned by `task_api.register_seqio_task`, in
    registration order.
  """
  registered = []
  for task_name, num_shots in task_names_and_shots:
    task_dir = data_path + "/" + task_name
    subtask_names = [
        found.split("/")[-2] for found in glob.glob(task_dir + "/*/")
    ]
    # No sub-directories: register the task itself exactly once.
    for subtask in subtask_names or [None]:
      registered.append(
          task_api.register_seqio_task(
              bigbench_task_name=task_name,
              bigbench_task_path=task_dir + "/task.json",
              bigbench_task_type=bb.BigBenchTaskType.GENERATIVE,
              vocab=vocab,
              num_shots=num_shots,
              bigbench_subtask_name=subtask,
              max_examples=None,
          )
      )
  return registered


# T5
def register_all_bigbench_eval_tasks(sequence_length, vocab = T5_DEFAULT_VOCABULARY, vocab_name = "t5_default"):
  """Register the BIG-bench tasks and the evaluation mixtures built on them.

  Steps, in order:
    1. Register all BIG-bench JSON tasks with 1, 2 and 3 shots.
    2. Register the arithmetic and proofs mixtures over the 2- and 3-shot
       task names.
    3. Register the custom `wmt`, `code_translation` and `GSM8K` tasks and a
       `code_translation` mixture scored with "bleu".

  Args:
    sequence_length: used in the mixture names and forwarded to
      `register_subset_tasks`.
    vocab: vocabulary wrapped into a `task_api.SeqIOVocabulary`.
    vocab_name: name and description given to that wrapper.

  Returns:
    A 3-tuple with the sorted task-name lists for 1, 2 and 3 shots.
  """
  custom_vocab = task_api.SeqIOVocabulary(
      name=vocab_name, description=vocab_name, vocabulary=vocab)

  all_bigbench_tasks_1shots, all_bigbench_tasks_2shots, all_bigbench_tasks_3shots = [
      register_all_bigbench_tasks(shots, custom_vocab) for shots in (1, 2, 3)
  ]

  # Disabled: list_functions mixtures.
  # register_list_function(all_bigbench_tasks_2shots, f"bigbench:2shot_list_functions_mix_{sequence_length}", sequence_length)
  # register_list_function(all_bigbench_tasks_3shots, f"bigbench:3shot_list_functions_mix_{sequence_length}", sequence_length)

  # Only the 2- and 3-shot variants feed the mixtures below.
  all_tasks = all_bigbench_tasks_2shots + all_bigbench_tasks_3shots

  # Disabled: keep one variant per task (gen preferred, then mul over scr)
  # and take every other task.
  # s = set([ tuple(x.split(".")[:2]) for x in all_tasks ] )
  # un = {}
  # for x in s:
  #   if x[0] not in un:
  #     un[x[0]] = x[1]
  #   else:
  #     if x[1] == "gen":
  #       un[x[0]] = x[1]
  #     elif x[1] == "mul" and un[x[0]] == "scr":
  #       un[x[0]] = x[1]
  # half_tasks = sorted([ ".".join(x) for x in un.items()])[::2]

  # Disabled (TOFIX: too slow): training mixtures.
  # train_tasks = [
  #   "bigbench:conlang_translation.gen",
  #   "bigbench:linguistic_mappings.gen",
  #   "bigbench:language_games.gen",
  #   "bigbench:natural_instructions.gen",
  #   "bigbench:unnatural_in_context_learning.gen",
  # ]
  # register_subset_tasks(all_tasks, train_tasks, f"bigbench:train_tasks_mix_{sequence_length}", sequence_length, eval = False)
  # register_subset_tasks(all_tasks, train_tasks, f"bigbench:train_tasks_eval_mix_{sequence_length}", sequence_length, eval = True)

  arithmetic_prefixes = [
      "bigbench:arithmetic",
      "bigbench:modified_arithmetic",
      "bigbench:repeat_copy_logic",
      "bigbench:elementary_math_qa",
  ]
  register_subset_tasks(
      all_tasks,
      arithmetic_prefixes,
      f"bigbench:arithmatics_tasks_mix_{sequence_length}",
      sequence_length,
  )

  proof_prefixes = ["bigbench:" + x for x in proofs_tasks]
  register_subset_tasks(
      all_tasks,
      proof_prefixes,
      f"bigbench:proofs_tasks_mix_{sequence_length}",
      sequence_length,
  )

  # Disabled: per-task mixtures with task-specific metrics.
  # register_subset_tasks(all_bigbench_tasks_2shots+ all_bigbench_tasks_3shots, ["bigbench:conlang_translation.gen"], f"bigbench:conlang_translation_mix_{sequence_length}", sequence_length, "rougeLsum")
  # register_subset_tasks(all_bigbench_tasks_2shots+ all_bigbench_tasks_3shots, ["bigbench:linguistics_puzzles.gen"], f"bigbench:linguistics_puzzles_mix_{sequence_length}", sequence_length)
  # register_subset_tasks(all_bigbench_tasks_2shots+ all_bigbench_tasks_3shots, ["bigbench:language_games.gen"], f"bigbench:language_games_mix_{sequence_length}", sequence_length, "bleu")

  # (task directory name, number of shots) for the locally stored tasks.
  custom_task_specs = [
      ("wmt-0shot", 0), ("wmt", 1), ("wmt", 5),
      ("code_translation" , 1), ("code_translation" , 2), ("code_translation" , 3),
      ("GSM8K" , 1), ("GSM8K" , 2), ("GSM8K" , 3),
  ]
  custom_tasks_names = register_custom_tasks(custom_task_specs, custom_vocab)

  # register_subset_tasks(custom_tasks_names, ["bigbench:wmt"], f"bigbench:wmt_mix_{sequence_length}", sequence_length, "bleu")
  register_subset_tasks(
      custom_tasks_names,
      ["bigbench:code_translation"],
      f"bigbench:code_translation_mix_{sequence_length}",
      sequence_length,
      "bleu",
  )
  # register_subset_tasks(custom_tasks_names, ["bigbench:GSM8K"], f"bigbench:GSM8K_mix_{sequence_length}", sequence_length, "GSM8K_score")
  return all_bigbench_tasks_1shots, all_bigbench_tasks_2shots, all_bigbench_tasks_3shots




from bigbench.api import json_task
from typing import List, Any, Dict, Union, Optional
import numpy as np
import tensorflow.compat.v2 as tf
from bigbench.bbseqio import bigbench_json_paths as bb_json_paths


def custom_format_fn(
    sample: Dict[str, Any],
    input_prefix: str,
    output_prefix: str,
    choice_prefix: str,
    random_seed: int,
    append_choices_to_input: bool = True,
) -> Dict[str, Any]:
    """Format one example so that its target holds the whole input/output text.

    The sample is modified in place and also returned. After the call,
    `sample["target"]` is `input_prefix + input + output_prefix + target`,
    where `input` already includes the appended choices (if any).

    Args:
      sample: Dictionary with an 'input' entry and a 'target' or
        'target_scores' entry (or both), describing a single example.
      input_prefix: prefix prepended to the input.
      output_prefix: text placed between the input and the target.
      choice_prefix: prefix prepended to each choice that is appended to the
        input of a multiple-choice example.
      random_seed: seed of the random number generator used to permute the
        choices and to pick one target out of a list of targets.
      append_choices_to_input: append the (permuted) choices to the input of
        a multiple-choice example.
    Returns:
      sample: The same dictionary, formatted, with a 'choice' key added when
        'target_scores' is present.
    Raises:
      ValueError: If 'target' is absent and 'target_scores' contains no
        correct choice (a single choice scored 0, or several choices whose
        maximum score does not exceed the minimum).
      KeyError: If the sample has neither 'target' nor 'target_scores'.
    """
    rng = np.random.RandomState(seed=random_seed)

    if "target_scores" in sample:
        scores = sample["target_scores"]

        if append_choices_to_input:
            # Sort first so the permutation depends only on the seed.
            shuffled_choices = rng.permutation(sorted(list(scores.keys())))
            sample["input"] = (
                sample["input"]
                + choice_prefix
                + choice_prefix.join(shuffled_choices)
            )

        if "target" not in sample:
            best = max(scores.values())  # type: ignore
            worst = min(scores.values())
            n_choices = len(scores)
            lone_wrong_choice = n_choices == 1 and best == 0
            no_clear_winner = n_choices > 1 and best <= worst
            # Correct choices must score higher than incorrect ones.
            if lone_wrong_choice or no_clear_winner:
                raise ValueError(f"No correct choice provided in example: {sample}")
            # On ties the first top-scoring choice (dict order) wins.
            top_choices = [k for k, v in scores.items() if v == best]
            sample["target"] = top_choices[0]  # type: ignore

        sample["choice"] = list(sample["target_scores"].keys())

    if isinstance(sample["target"], list):
        sample["target"] = rng.choice(sample["target"])

    prompt = input_prefix + sample["input"]
    sample["target"] = prompt + output_prefix + sample["target"]
    return sample

def task_to_text(json_task):
    """Render all examples of a task as one text block in random order.

    The task's `_ds` list is shuffled in place with the global NumPy RNG.
    The "target" field of every example is then joined with the task's
    "few_shot_example_separator" (default "\\n") and prefixed with the
    stringified "task_prefix" (default ""), both read from `task_data`.

    Args:
      json_task: object exposing `_ds` (a list of dicts with a "target" key)
        and `task_data` (a dict).

    Returns:
      The task prefix followed by the separator-joined targets.
    """
    np.random.shuffle(json_task._ds)
    task_data = json_task.task_data
    separator = task_data.get("few_shot_example_separator", "\n")
    task_prefix = str(task_data.get("task_prefix", ""))
    targets = [example["target"] for example in json_task._ds]
    return task_prefix + separator.join(targets)
    
def get_dataset_fn(selected_task_names, format_fn):
    """Build a seqio-style dataset function over BIG-bench JSON tasks.

    Args:
      selected_task_names: BIG-bench task names; each is resolved with
        `bb_json_paths.get_task_path` and loaded as a `json_task.JsonTask`.
      format_fn: example formatting function handed to every `JsonTask`.

    Returns:
      A function `dataset_fn(split, shuffle_files)` returning a
      `tf.data.Dataset` of `{"task_name": string, "targets": string}`
      elements. Both arguments are ignored, so every split is the same
      stream. The stream is endless unless no task was selected, in which
      case it is empty.
    """
    def dataset_fn(split, shuffle_files):
        json_tasks = [
            json_task.JsonTask(
                bb_json_paths.get_task_path(task_name)[0], format_fn=format_fn)
            for task_name in selected_task_names
        ]

        def gen():
            # Without this guard the `while True` loop below would spin
            # forever without ever yielding.
            if not json_tasks:
                return

            # The task order is drawn once and reused for every pass.
            task_order = np.arange(len(json_tasks))
            np.random.shuffle(task_order)
            while True:
                for idx in task_order:
                    t = json_tasks[idx]
                    if len(t.subtasks) == 0:
                        yield {
                            "targets": tf.constant(task_to_text(t)),
                            "task_name": tf.constant(t.name),
                        }
                        continue

                    # Shuffle the subtask order itself. The previous code
                    # reshuffled the outer task order here, which left the
                    # subtasks unshuffled and disturbed the outer loop while
                    # it was being iterated.
                    subtask_order = np.arange(len(t.subtasks))
                    np.random.shuffle(subtask_order)
                    for idx_s in subtask_order:
                        subtask = t.subtasks[idx_s]
                        yield {
                            "targets": tf.constant(task_to_text(subtask)),
                            "task_name": tf.constant(subtask.name),
                        }

        ds = tf.data.Dataset.from_generator(
            gen,
            output_signature={
                "task_name": tf.TensorSpec(shape=(), dtype=tf.string),
                "targets": tf.TensorSpec(shape=(), dtype=tf.string)
            }
        )
        return ds
    return dataset_fn


# Hand-curated subset of BIG-bench JSON tasks. Entries that are commented out
# are excluded from the subset.
# NOTE: `selected_task_names` is rebound further down in this file (derived
# from `bb_json_paths.get_task_names()`), so this literal is superseded
# before it is used to register any task.
selected_task_names = [
    # "abstract_narrative_understanding",
    # "anachronisms",
    # "analogical_similarity",
    # "analytic_entailment",
    # "arithmetic",
    # "ascii_word_recognition",
    "authorship_verification",
    # "auto_categorization",
    # "auto_debugging",
    # "bbq_lite_json",
    # "bridging_anaphora_resolution_barqa",
    "causal_judgment",
    "cause_and_effect",
    # "checkmate_in_one",
    # "chess_state_tracking",
    # "chinese_remainder_theorem",
    # "cifar10_classification",
    "code_line_description",
    # "codenames",
    # "color",
    "common_morpheme",
    # "conceptual_combinations",
    "conlang_translation",
    # "contextual_parametric_knowledge_conflicts",
    # "crash_blossom",
    # "crass_ai",
    # "cryobiology_spanish",
    # "cryptonite",
    # "cs_algorithms",
    # "dark_humor_detection",
    # "date_understanding",
    # "disambiguation_qa",
    # "discourse_marker_prediction",
    # "disfl_qa",
    "dyck_languages",
    # "elementary_math_qa",
    # "emoji_movie",
    # "emojis_emotion_prediction",
    # "empirical_judgments",
    "english_proverbs",
    "english_russian_proverbs",
    # "entailed_polarity",
    # "entailed_polarity_hindi",
    # "epistemic_reasoning",
    # "evaluating_information_essentiality",
    "fact_checker",
    # "fantasy_reasoning",
    "few_shot_nlg",
    # "figure_of_speech_detection",
    # "formal_fallacies_syllogisms_negation",
    # "gem",
    "gender_inclusive_sentences_german",
    "general_knowledge",
    # "geometric_shapes",
    # "goal_step_wikihow",
    # "gre_reading_comprehension",
    # "hhh_alignment",
    # "hindi_question_answering",
    # "hindu_knowledge",
    # "hinglish_toxicity",
    # "human_organs_senses",
    # "hyperbaton",
    # "identify_math_theorems",
    # "identify_odd_metaphor",
    # "implicatures",
    # "implicit_relations",
    # "indic_cause_and_effect",
    # "intent_recognition",
    # "international_phonetic_alphabet_nli",
    # "international_phonetic_alphabet_transliterate",
    # "intersect_geometry",
    # "irony_identification",
    # "kanji_ascii",
    # "kannada",
    # "key_value_maps",
    # "known_unknowns",
    "language_games",
    "language_identification",
    "linguistic_mappings",
    "linguistics_puzzles",
    "list_functions",
    # "logic_grid_puzzle",
    # "logical_args",
    # "logical_deduction",
    # "logical_fallacy_detection",
    # "logical_sequence",
    # "mathematical_induction",
    # "matrixshapes",
    # "medical_questions_russian",
    # "metaphor_boolean",
    # "metaphor_understanding",
    # "minute_mysteries_qa",
    # "misconceptions",
    # "misconceptions_russian",
    # "mnist_ascii",
    # "modified_arithmetic",
    # "moral_permissibility",
    # "movie_dialog_same_or_different",
    # "movie_recommendation",
    # "mult_data_wrangling",
    # "multiemo",
    "natural_instructions",
    # "navigate",
    "nonsense_words_grammar",
    # "novel_concepts",
    "object_counting",
    # "odd_one_out",
    # "operators",
    "paragraph_segmentation",
    "parsinlu_qa",
    "parsinlu_reading_comprehension",
    "penguins_in_a_table",
    "periodic_elements",
    "persian_idioms",
    "phrase_relatedness",
    "physical_intuition",
    "physics",
    "physics_questions",
    # "play_dialog_same_or_different",
    # "polish_sequence_labeling",
    # "presuppositions_as_nli",
    "qa_wikidata",
    "question_selection",
    "real_or_fake_text",
    # "reasoning_about_colored_objects",
    "repeat_copy_logic",
    "rephrase",
    # "riddle_sense",
    # "ruin_names",
    # "salient_translation_error_detection",
    "scientific_press_release",
    # "semantic_parsing_in_context_sparc",
    # "semantic_parsing_spider",
    # "sentence_ambiguity",
    # "similarities_abstraction",
    # "simp_turing_concept",
    # "simple_arithmetic_json",
    # "simple_arithmetic_json_multiple_choice",
    # "simple_arithmetic_json_subtasks",
    # "simple_arithmetic_multiple_targets_json",
    "simple_ethical_questions",
    # "simple_text_editing",
    # "snarks",
    "social_iqa",
    "social_support",
    "sports_understanding",
    # "strange_stories",
    # "strategyqa",
    # "sufficient_information",
    # "suicide_risk",
    "swahili_english_proverbs",
    "swedish_to_german_proverbs",
    # "symbol_interpretation",
    # "temporal_sequences",
    # "tense",
    # "timedial",
    # "topical_chat",
    # "tracking_shuffled_objects",
    "understanding_fables",
    "undo_permutation",
    # "unit_conversion",
    # "unit_interpretation",
    "unnatural_in_context_learning",
    # "vitaminc_fact_verification",
    # "what_is_the_tao",
    "which_wiki_edit",
    # "winowhy",
    "word_sorting",
    "word_unscrambling",
]

# Task names reported by the BIG-bench JSON path helper.
all_task_names = bb_json_paths.get_task_names()

# Training selection: every reported task except the arithmetic and proofs
# groups. This rebinding replaces the hand-curated `selected_task_names`
# literal defined above. The result is sorted because converting a set of
# strings straight to a list gives an order that varies between interpreter
# runs, and this list is baked into a registered task.
selected_task_names = sorted(
    set(all_task_names) - set(arithmatics_tasks) - set(proofs_tasks)
)

# Training task over `selected_task_names`: one string of formatted examples
# per (sub)task, tokenized with LARGER_VOCAB and without an EOS token.
seqio.TaskRegistry.add(
    "train_bigbench_64k_vocab",
    seqio.FunctionDataSource(
        get_dataset_fn(selected_task_names, format_fn=custom_format_fn),
        splits=["train", "test"],
    ),
    preprocessors=[
        seqio.preprocessors.tokenize,
        # Disabled preprocessors:
        # t5.data.preprocessors.concatenate_and_split_to_fixed_length,
        # functools.partial(
        #     dropout_shuffle,
        #     dropout_prob = 0.5,
        #     vocab_size = LARGER_VOCAB.vocab_size
        # ),
    ],
    output_features={
        "targets": seqio.Feature(LARGER_VOCAB, add_eos=False, dtype=tf.int32),
    },
)

# Wiki/code mixture blended with the selected BIG-bench training task (10:1).
seqio.MixtureRegistry.add(
    "wiki_code_mix_64k_vocab_bigbench",
    [
        ("wiki_code_mix_64k_vocab", 10),
        ("train_bigbench_64k_vocab", 1),
    ],
)

# The Pile blended with the selected BIG-bench training task (10000:1).
# NOTE(review): the sibling mixture "the_pile_64k_vocab_all_bigbench" uses a
# rate of 1000 for "the_pile" -- confirm the difference is intentional.
seqio.MixtureRegistry.add(
    "the_pile_64k_vocab_bigbench",
    [
        ("the_pile", 10000),
        ("train_bigbench_64k_vocab", 1),
    ],
)


# Training task over every name in `all_task_names`: one string of formatted
# examples per (sub)task, tokenized with LARGER_VOCAB and without an EOS token.
seqio.TaskRegistry.add(
    "train_all_bigbench_64k_vocab",
    seqio.FunctionDataSource(
        get_dataset_fn(all_task_names, format_fn=custom_format_fn),
        splits=["train", "test"],
    ),
    preprocessors=[
        seqio.preprocessors.tokenize,
        # Disabled preprocessors:
        # t5.data.preprocessors.concatenate_and_split_to_fixed_length,
        # functools.partial(
        #     dropout_shuffle,
        #     dropout_prob = 0.5,
        #     vocab_size = LARGER_VOCAB.vocab_size
        # ),
    ],
    output_features={
        "targets": seqio.Feature(LARGER_VOCAB, add_eos=False, dtype=tf.int32),
    },
)


# Wiki/code mixture blended with the all-task BIG-bench training task (10:1).
seqio.MixtureRegistry.add(
    "wiki_code_mix_64k_vocab_all_bigbench",
    [
        ("wiki_code_mix_64k_vocab", 10),
        ("train_all_bigbench_64k_vocab", 1),
    ],
)

# The Pile blended with the all-task BIG-bench training task (1000:1).
seqio.MixtureRegistry.add(
    "the_pile_64k_vocab_all_bigbench",
    [
        ("the_pile", 1000),
        ("train_all_bigbench_64k_vocab", 1),
    ],
)

#############################

# Equal-rate mixture of the two Pile variants.
seqio.MixtureRegistry.add(
    "the_pile_gcs_and_filtered",
    [
        ("the_pile_gcs", 1),
        ("the_filtered_pile", 1),
    ],
)