import seqio
from bigbench.bbseqio import task_api
from bigbench.bbseqio import tasks as bbtasks

import tensorflow as tf
import functools
import sys

def filter_too_long(sequence_length, ds):
  """Drop examples whose combined input and target length is too large.

  An example is kept only when ``len(inputs) + len(targets)`` is strictly
  smaller than ``sequence_length - 1``.

  Args:
    sequence_length: Length budget the combined example is compared against.
    ds: Dataset-like object exposing ``filter(predicate)``; its elements are
      indexed with the keys ``'inputs'`` and ``'targets'``.

  Returns:
    Whatever ``ds.filter`` returns for the length predicate.
  """

  def _is_short_enough(ex):
    combined_length = len(ex['inputs']) + len(ex['targets'])
    return combined_length < sequence_length - 1

  return ds.filter(_is_short_enough)


def modify_task_preprocessors(subtasks, sequence_length):
  """Register length-filtered copies of already-registered seqio tasks.

  Simplified version of tasks.modify_bigbench_tasks that does not interact
  with tasks.ALL_INDEXED_TASKS. For every name in ``subtasks`` a task called
  ``"<name>_filtered_<sequence_length>"`` is added to ``seqio.TaskRegistry``.
  The new task reuses the source, output features, postprocessor and metric
  functions of the original task and runs ``filter_too_long`` (bound to
  ``sequence_length``) after the original preprocessors.

  A filtered name that is already present in the registry is not registered
  again, so calling this function repeatedly with overlapping task lists does
  not fail on a duplicate registration.

  Args:
    subtasks: Iterable of names of tasks present in ``seqio.TaskRegistry``.
    sequence_length: Length budget passed to ``filter_too_long``; it is also
      embedded in the name of each new task.

  Returns:
    List with the names of the filtered tasks, in the order of ``subtasks``.
  """
  new_subtasks_names = []
  for t in subtasks:
    filtered_name = f"{t}_filtered_{sequence_length}"
    # Registering the same name twice raises, so reuse an existing entry.
    if filtered_name not in seqio.TaskRegistry.names():
      task = seqio.TaskRegistry.get(t)
      length_filter = functools.partial(filter_too_long, sequence_length)
      seqio.TaskRegistry.add(
        filtered_name,
        task.source,
        # The length filter runs last, after the task's own preprocessors.
        preprocessors=tuple(task.preprocessors) + (length_filter,),
        output_features=task.output_features,
        postprocess_fn=task.postprocessor,
        metric_fns=task.metric_fns,
      )
    new_subtasks_names.append(filtered_name)
  return new_subtasks_names


def register_all_bigbench_tasks(num_shots, custom_vocab):
  """Register all BIG-bench JSON tasks and list the resulting task names.

  Args:
    num_shots: Forwarded to ``bbtasks.register_bigbench_json``.
    custom_vocab: Forwarded to ``bbtasks.register_bigbench_json``.

  Returns:
    Sorted list with the ``name`` of every task found in the mixture (or
    task) whose name ``bbtasks.register_bigbench_json`` returned.
  """
  mixture_name = bbtasks.register_bigbench_json(num_shots, custom_vocab)
  mixture = seqio.get_mixture_or_task(mixture_name)

  task_names = [member.name for member in mixture.tasks]
  task_names.sort()
  return task_names

def register_cache_tasks(sequence_length, shot, vocab, vocab_name, begin = 0, end = 1, have_all = False):
  """Register BIG-bench tasks and their length-filtered variants.

  Only the BIG-bench JSON tasks for the single shot count ``shot`` are
  registered here. This function previously used
  tasks.register_all_bigbench_eval_tasks.

  Args:
    sequence_length: Length budget handed to ``modify_task_preprocessors``.
    shot: Number of shots passed to ``register_all_bigbench_tasks``.
    vocab: Vocabulary object wrapped in a ``task_api.SeqIOVocabulary``.
    vocab_name: Used as both the name and the description of that wrapper.
    begin: Start index of the slice of sorted task names to process.
    end: Stop index (exclusive) of the slice of sorted task names.
    have_all: When true, ``begin`` and ``end`` are ignored and every task is
      processed.

  Returns:
    The list of task names returned by ``modify_task_preprocessors``.
  """
  vocabulary = task_api.SeqIOVocabulary(
    name=vocab_name,
    description=vocab_name,
    vocabulary=vocab,
  )
  candidate_tasks = sorted(register_all_bigbench_tasks(shot, vocabulary))

  # Report how many tasks are available before any slicing happens.
  print(len(candidate_tasks))

  selected_tasks = candidate_tasks if have_all else candidate_tasks[begin:end]

  # Register copies of the selected tasks with filter_too_long appended.
  return modify_task_preprocessors(selected_tasks, sequence_length=sequence_length)


  # # Disabled alternative, kept for reference: add fully cached versions of the tasks to the registry
  # cached_task_names = []
  # features = ['inputs', 'targets']
  # cur_seq_len = {feature:sequence_length // len(features) for feature in features}
  
  # for cur_task_name in subset_tasks:
  #     seqio.experimental.add_fully_cached_task(
  #         cur_task_name,
  #         sequence_length=cur_seq_len
  #     )
  #     cached_task_names.append(
  #       seqio.experimental._get_fully_cached_name(
  #         cur_task_name,
  #         sequence_length=cur_seq_len
  #       )
  #     )
  # print(cached_task_names)
  # return cached_task_names
