import jax
from jax import numpy as jnp

def make_preprocess_seqs(tokenizer, n_partitions, max_length, add_position_ids=False):
    """Build a function that tokenizes a batch and optionally partitions it.

    Args:
        tokenizer: Callable invoked as ``tokenizer(batch, return_tensors='np',
            padding='max_length', max_length=max_length, truncation=True)``.
            Its result is converted with ``dict(...)`` and must contain an
            ``'input_ids'`` array.
        n_partitions: Number of contiguous partitions to split the tokenized
            batch into. When it is not greater than 0 the tokenized batch is
            returned as a single dict instead of a list.
        max_length: Forwarded to the tokenizer as ``max_length``.
        add_position_ids: Currently unused; kept so existing callers keep
            working. ``'position_ids'`` is added whenever the tokenizer
            output does not already contain that key.

    Returns:
        A function ``preprocess_seqs(batch)`` returning either one dict of
        arrays, or a list of ``n_partitions`` dicts holding consecutive row
        slices of every array.
    """
    def preprocess_seqs(batch):
        seqs = dict(tokenizer(batch, return_tensors='np', padding='max_length',
                              max_length=max_length, truncation=True))
        if 'position_ids' not in seqs:
            input_ids = seqs['input_ids']
            # Positions 0..seq_len-1, repeated to match the shape of input_ids.
            seq_len = jnp.atleast_2d(input_ids).shape[-1]
            seqs['position_ids'] = jnp.broadcast_to(jnp.arange(seq_len), input_ids.shape)

        if n_partitions <= 0:
            return seqs

        # Count rows on the tokenized output (what is actually sliced) rather
        # than on the raw batch, and spread any remainder over the leading
        # partitions so that no example is silently dropped. When the row
        # count divides evenly this yields equally sized slices.
        num_rows = len(seqs['input_ids'])
        base, extra = divmod(num_rows, n_partitions)
        partitions = []
        start = 0
        for index in range(n_partitions):
            stop = start + base + (1 if index < extra else 0)
            partitions.append({key: value[start:stop] for key, value in seqs.items()})
            start = stop
        return partitions

    return preprocess_seqs


def make_apply_model(model):
    """Return a function that runs ``model.module.apply`` on one tokenized input.

    Args:
        model: Object exposing ``module.apply`` and ``params``.

    Returns:
        A function ``apply_bert(seq)`` that unpacks the mapping ``seq`` as
        keyword arguments and forwards it, together with fixed inference
        flags, to ``model.module.apply``; the call's result is returned as is.
    """
    # Fixed keyword flags passed on every call.
    fixed_flags = {
        'deterministic': True,
        'output_attentions': False,
        'output_hidden_states': True,
        'return_dict': True,
    }

    # Not jit-compiled here (a jax.jit decorator was previously commented out).
    def apply_bert(seq):
        # Read model.params at call time so later reassignment is picked up.
        variables = {'params': model.params}
        return model.module.apply(variables, **seq, **fixed_flags)

    return apply_bert