from jax import numpy as jnp

from ProtLig_GPCRclassA.utils import serialize_ESM2_hidden_states

def make_preprocess_seqs(tokenizer, n_partitions, max_length, add_position_ids = False):
    """
    Build a function that tokenizes a batch of sequences and optionally
    splits the tokenized batch into equally sized partitions.

    Parameters
    ----------
    tokenizer : callable
        Invoked as ``tokenizer(batch, return_tensors='pt',
        padding='max_length', max_length=max_length, truncation=True)``.
        Its result must be convertible with ``dict(...)``.
    n_partitions : int
        If greater than 0, the tokenized batch is split into this many
        consecutive partitions; otherwise it is returned unsplit.
    max_length : int
        Forwarded to the tokenizer as the padding / truncation length.
    add_position_ids : bool
        Generating position ids is not implemented. When True, the tokenizer
        itself has to return a ``'position_ids'`` entry.

    Returns
    -------
    callable
        ``preprocess_seqs(batch)`` returning either one dict of tokenizer
        outputs (``n_partitions <= 0``) or a list of ``n_partitions`` dicts
        with the same keys, each holding one slice of every entry.

    Raises
    ------
    NotImplementedError
        Raised by ``preprocess_seqs`` when ``add_position_ids`` is True and
        the tokenizer output has no ``'position_ids'`` entry.
    """
    def preprocess_seqs(batch):
        encoded = dict(tokenizer(batch,
                                 return_tensors='pt',
                                 padding='max_length',
                                 max_length=max_length,
                                 truncation=True))

        if add_position_ids and 'position_ids' not in encoded:
            raise NotImplementedError(
                "add_position_ids=True is not implemented for tokenizers "
                "that do not return 'position_ids'")

        if n_partitions > 0:
            # Floor division: every partition has the same size, and trailing
            # entries that do not fill a whole partition are left out.
            partition_size = len(batch) // n_partitions
            return [
                {key: value[i * partition_size:(i + 1) * partition_size]
                 for key, value in encoded.items()}
                for i in range(n_partitions)
            ]
        else:
            return encoded

    return preprocess_seqs


def make_apply_model(model):
    """
    Wrap ``model`` in a single-argument forward function.

    The returned ``apply_model(seq)`` unpacks the mapping ``seq`` into
    keyword arguments for ``model`` and always adds
    ``output_attentions=False``, ``output_hidden_states=True`` and
    ``return_dict=True``. Whatever ``model`` returns is passed through.
    """
    # Flags appended to every call, after the entries of ``seq``.
    forward_flags = dict(output_attentions=False,
                         output_hidden_states=True,
                         return_dict=True)

    def apply_model(seq):
        return model(**seq, **forward_flags)

    return apply_model



# --------------- --------------- --------------- --------------- ---------------
# make_apply_seqs_model:
# --------------- --------------- --------------- --------------- ---------------
def make_apply_seqs_model(seq_model, tokenizer, max_length, n_partitions = 0):
    """
    Build a function mapping raw sequences to stacked hidden states and
    attention masks.

    Parameters
    ----------
    seq_model : callable
        Model invoked through ``make_apply_model``; its output must expose
        a ``hidden_states`` attribute.
    tokenizer : callable
        Tokenizer handed to ``make_preprocess_seqs``.
    max_length : int
        Padding / truncation length handed to ``make_preprocess_seqs``.
    n_partitions : int
        Only values ``<= 0`` (no partitioning) are supported.

    Returns
    -------
    callable
        ``apply_seqs_model(seqs)`` returning a tuple
        ``(hidden_states, attention_masks)``: two ``jnp.stack``-ed arrays
        with one entry per input sequence, cast to float32 and int32
        respectively.

    Raises
    ------
    NotImplementedError
        Raised by ``apply_seqs_model`` when ``n_partitions > 0``.
    """
    preprocess_seqs = make_preprocess_seqs(tokenizer,
                                        n_partitions = n_partitions,
                                        max_length = max_length,
                                        add_position_ids = False)
    apply_model = make_apply_model(model = seq_model)

    def apply_seqs_model(seqs):
        # Fail before doing any work: with partitioning enabled,
        # preprocess_seqs returns a list of dicts, which cannot be unpacked
        # into the model call below.
        if n_partitions > 0:
            raise NotImplementedError('needs to be checked...')

        _seqs = preprocess_seqs(seqs)
        output = apply_model(_seqs)

        # Torch tensor (the tokenizer is called with return_tensors='pt').
        attn_mask = _seqs['attention_mask']
        hidden_states = serialize_ESM2_hidden_states(output.hidden_states)

        n_seqs = len(seqs)
        # Index -1 on the first axis of each per-sequence entry; presumably
        # the final layer -- TODO confirm against the layout produced by
        # serialize_ESM2_hidden_states.
        seqs_hidden_states = [hidden_states[j][-1, :, :].astype(jnp.float32)
                              for j in range(n_seqs)]
        seqs_attn = [attn_mask[j].detach().numpy().astype(jnp.int32)
                     for j in range(n_seqs)]

        return (jnp.stack(seqs_hidden_states), jnp.stack(seqs_attn))

    return apply_seqs_model
# --------------- --------------- --------------- --------------- ---------------