"""Common stuff amongst the protein datasets."""
import tensorflow as tf
from transformers import PreTrainedTokenizer

# # Use data at a particular github commit for consistency.
# PROTEIN_BERT_CSVS_DIR = "https://raw.githubusercontent.com/nadavbra/protein_bert/2090b78f09e70f0b960d7b591bf80d3d411c9d3f/protein_benchmarks"


def preprocess_aa_sequence_tf(aa_seq: tf.Tensor):
    """Normalize raw amino-acid strings into the spaced form ProtBert expects.

    The steps are applied element-wise to a string tensor:

    1. Upper-case every character.
    2. Replace the residue codes U, Z, O and B with X.
    3. Insert a space after every word character (``\\w``: letters, digits
       and underscore), so residues become whitespace-separated tokens.
    4. Strip leading/trailing whitespace (removes the trailing space added
       in step 3).

    Args:
        aa_seq: String tensor of raw amino-acid sequences.

    Returns:
        String tensor with the preprocessed sequences, e.g. ``b"mkb"`` ->
        ``b"M K X"``.
    """
    upper_cased = tf.strings.upper(aa_seq)
    # Done after upper-casing so lower-case u/z/o/b are also mapped.
    canonical = tf.strings.regex_replace(upper_cased, r"[UZOB]", "X")
    space_separated = tf.strings.regex_replace(canonical, r'(\w)', r'\1 ')
    return tf.strings.strip(space_separated)


def encode_aa_sequence_tf(tokenizer: PreTrainedTokenizer, sequence_length: int, aa_seq: tf.Tensor):
    """Tokenize one amino-acid sequence into fixed-length model inputs.

    The sequence is first normalized with ``preprocess_aa_sequence_tf``.
    It is then encoded by the Hugging Face ``tokenizer`` inside
    ``tf.py_function``, so this can be called from graph code such as a
    ``tf.data`` map function.

    Args:
        tokenizer: Tokenizer used to encode the space-separated residues.
        sequence_length: Length every output is padded or truncated to,
            special tokens included.
        aa_seq: String tensor holding a single sequence. The Python-side
            decode uses ``tf.compat.as_str`` on its value, so this is
            assumed to be a scalar.

    Returns:
        Dict with keys ``"input_ids"``, ``"token_type_ids"`` and
        ``"attention_mask"``. Each value is an int32 tensor with static
        shape ``[sequence_length]``.
    """

    def py_fn(aa_seq: tf.Tensor):
        # Runs eagerly: the argument is an EagerTensor, not a str, so pull
        # out its bytes and decode them before tokenizing.
        text = tf.compat.as_str(aa_seq.numpy())
        inputs = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=sequence_length,
            return_token_type_ids=True,
            return_attention_mask=True,
            padding='max_length',
            truncation=True,
        )
        # Plain Python lists (no return_tensors) keep this independent of
        # the tokenizer's TensorFlow tensor backend; tf.py_function converts
        # each list to the matching Tout dtype.
        return inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"]

    aa_seq = preprocess_aa_sequence_tf(aa_seq)
    input_ids, token_type_ids, attention_mask = tf.py_function(
        func=py_fn,
        inp=[aa_seq],
        Tout=[tf.int32, tf.int32, tf.int32],
    )
    return {
        # tf.py_function drops static shape information; restore it because
        # downstream steps (batching, model input signatures) often need it.
        "input_ids": tf.reshape(input_ids, [sequence_length]),
        "token_type_ids": tf.reshape(token_type_ids, [sequence_length]),
        "attention_mask": tf.reshape(attention_mask, [sequence_length]),
    }


def get_supervised_encode_aa_sequence_tf_fn(tokenizer: PreTrainedTokenizer, sequence_length: int):
    """Build a function mapping ``(aa_seq, label)`` to ``(features, label)``.

    The returned callable tokenizes ``aa_seq`` with ``encode_aa_sequence_tf``
    using the captured ``tokenizer`` and ``sequence_length``. It passes
    ``label`` through unchanged, which makes it usable as e.g. a
    ``tf.data.Dataset.map`` function over (sequence, label) pairs.

    Args:
        tokenizer: Tokenizer forwarded to ``encode_aa_sequence_tf``.
        sequence_length: Fixed output length forwarded to
            ``encode_aa_sequence_tf``.

    Returns:
        A two-argument callable ``(aa_seq, label) -> (features_dict, label)``.
    """

    def encode_example(aa_seq, label):
        return encode_aa_sequence_tf(tokenizer, sequence_length, aa_seq), label

    return encode_example
