R"""SignalP signal peptide dataset: a TFDS builder plus task loaders.

Original papers:
    https://core.ac.uk/download/pdf/269313817.pdf
    https://www.nature.com/articles/s41587-021-01156-3



To download and prepare:


from em.datasets.protein import signal_peptide
from em.util import vat_da_faak_vpn

for config in signal_peptide.SignalPeptide.BUILDER_CONFIGS:
    builder = signal_peptide.SignalPeptide(config=config)
    builder.download_and_prepare()


"""
import dataclasses
import itertools
import random

import tensorflow as tf
import tensorflow_datasets as tfds
from transformers import PreTrainedTokenizer

from . import common as protein_common 

###############################################################################


# Number of examples in each split of the "signal_peptide/sp6" TFDS dataset
# (train <- train_set.fasta, validation <- benchmark_set_sp5.fasta).
# NOTE(review): these counts are hard-coded; confirm they still match the
# prepared dataset if the upstream FASTA files change.
SP6_SPLIT_TO_N_EXAMPLES = {
    'train': 20290,
    'validation': 8811,
}


# Sequences appear to be at most 70 residues long, so use a sequence length
# of 72 when running this (presumably leaving room for two special tokens).


###############################################################################

# Task names accepted by `load`. 'sp6' keeps all SP_LABELS classes;
# 'sp6_binary' collapses them to NO_SP (label id 0) vs. any signal peptide.
_TASK_NAMES = ('sp6', 'sp6_binary')


def load(
    task: str,
    split: str,
    tokenizer: PreTrainedTokenizer,
    sequence_length: int,
):
    """Build a tf.data pipeline of encoded (sequence, label) pairs for a SignalP task.

    Reads the "signal_peptide/sp6" TFDS dataset, keeps only the amino-acid
    sequence and its integer label, and passes each pair through the shared
    protein encoding function. For 'sp6_binary' the label is additionally
    collapsed to 0 (NO_SP) / 1 (any signal peptide).

    Args:
        task: One of _TASK_NAMES ('sp6' or 'sp6_binary').
        split: TFDS split specification forwarded to `tfds.load`.
        tokenizer: Forwarded to the shared encoding-function factory.
        sequence_length: Forwarded to the shared encoding-function factory.

    Returns:
        The mapped dataset.

    Raises:
        ValueError: if `task` is not a known SignalPeptide task.
    """
    if task not in _TASK_NAMES:
        raise ValueError(f'Invalid SignalPeptide task: {task}')

    def _sequence_and_label(example):
        return example['aa_sequence'], example['label']

    raw = tfds.load("signal_peptide/sp6", split=split)
    encoded = raw.map(_sequence_and_label).map(
        protein_common.get_supervised_encode_aa_sequence_tf_fn(tokenizer, sequence_length))

    return encoded.map(_convert_to_binary_map_fn) if task == 'sp6_binary' else encoded


def n_classes_for_task(task: str) -> int:
    """Return how many label classes `task` predicts.

    Raises:
        ValueError: if `task` is neither 'sp6' nor 'sp6_binary'.
    """
    for known_task, n_classes in (('sp6', 6), ('sp6_binary', 2)):
        if task == known_task:
            return n_classes
    raise ValueError(task)


def de_facto_validation_split(task):
    """Return the split name used for validation; it is 'validation' for every task."""
    del task  # Unused: both tasks are loaded from the same TFDS splits.
    return 'validation'


def examples_per_epoch(task):
    """Return the number of training examples in one epoch (the same for every task)."""
    del task  # Unused: both tasks are built from the same training split.
    n_train = SP6_SPLIT_TO_N_EXAMPLES['train']
    return n_train


###############################################################################

def _convert_to_binary_map_fn(x, y):
    """Collapse the label to binary: 0 when it is 0 (NO_SP), 1 otherwise.

    `x` is passed through unchanged; the returned label is cast to int64.
    """
    is_signal_peptide = tf.not_equal(y, 0)
    return x, tf.cast(is_signal_peptide, tf.int64)


###############################################################################
###############################################################################


# Project homepage, reported as `homepage` in SignalPeptide's DatasetInfo.
_SP_URL = "https://services.healthtech.dtu.dk/service.php?SignalP"

# BibTeX entries for the SignalP 5.0 and SignalP 6.0 papers, reported as
# `citation` in SignalPeptide's DatasetInfo.
_SP_CITATION = R"""@article{almagro2019signalp,
  title={SignalP 5.0 improves signal peptide predictions using deep neural networks},
  author={Almagro Armenteros, Jos{\'e} Juan and Tsirigos, Konstantinos D and S{\o}nderby, Casper Kaae and Petersen, Thomas Nordahl and Winther, Ole and Brunak, S{\o}ren and von Heijne, Gunnar and Nielsen, Henrik},
  journal={Nature biotechnology},
  volume={37},
  number={4},
  pages={420--423},
  year={2019},
  publisher={Nature Publishing Group US New York}
}
@article{teufel2022signalp,
  title={SignalP 6.0 predicts all five types of signal peptides using protein language models},
  author={Teufel, Felix and Almagro Armenteros, Jos{\'e} Juan and Johansen, Alexander Rosenberg and G{\'\i}slason, Magn{\'u}s Halld{\'o}r and Pihl, Silas Irby and Tsirigos, Konstantinos D and Winther, Ole and Brunak, S{\o}ren and von Heijne, Gunnar and Nielsen, Henrik},
  journal={Nature biotechnology},
  volume={40},
  number={7},
  pages={1023--1025},
  year={2022},
  publisher={Nature Publishing Group US New York}
}"""

###############################################################################

# Class names for the "label" ClassLabel feature. Order matters: a label's
# position in this tuple is its integer id, and id 0 ('NO_SP') is the class
# the binary task treats as negative.
# The last two labels do not appear to be in the SP5 benchmark.
SP_LABELS = ('NO_SP', 'SP', 'TAT', 'LIPO', 'TATLIPO', 'PILIN')


@dataclasses.dataclass
class SpExample:
    """One record from a SignalP 3-line FASTA file.

    The first four fields come from the header line
    ``>uniprot_ac|kingdom|label|partition``; the last two are the record's
    amino-acid sequence line and annotation line.
    """

    uniprot_ac: str
    kingdom: str
    partition: int

    # Must be one of SP_LABELS (checked in __post_init__).
    label: str

    aa_sequence: str
    annotation_sequence: str

    def __post_init__(self):
        # Deliberately an assert (AssertionError), not a ValueError, so that an
        # unknown label is never confused with the ValueError that
        # parse_3_line_fasta raises on truncated input. Stripped under -O.
        assert self.label in SP_LABELS, self.label

    def to_tfds_dict(self):
        """Return the fields as a dict keyed by the TFDS feature names."""
        # Field names match the FeaturesDict keys one-to-one.
        return dataclasses.asdict(self)


def parse_3_line_fasta(f) -> SpExample:
    """Parse the next 3-line record from an iterable of FASTA text lines.

    A record is a header line ``>uniprot_ac|kingdom|label|partition``, followed
    by the amino-acid sequence line and the annotation line. Each line is
    stripped of surrounding whitespace (including its newline) and both
    sequences are upper-cased.

    Args:
        f: Iterable of text lines, e.g. an open file. Up to three lines are
            consumed from it.

    Returns:
        The parsed SpExample.

    Raises:
        ValueError: if fewer than three lines remain, the header does not have
            exactly four '|'-separated fields, or the partition is not an int.
        AssertionError: if the label is not in SP_LABELS (assertions enabled).
    """
    # Strip so the trailing '\n' (or '\r\n') is not stored in the sequences.
    header, aa_sequence, annotation_sequence = (
        line.strip() for line in itertools.islice(f, 3))
    uniprot_ac, kingdom, label, partition = header.split('|')
    return SpExample(
        # Remove the initial ">" on the uniprot_ac
        uniprot_ac=uniprot_ac[1:],
        kingdom=kingdom,
        partition=int(partition),
        label=label,
        aa_sequence=aa_sequence.upper(),
        annotation_sequence=annotation_sequence.upper(),
    )

###############################################################################


# URL template for the public SignalP 6.0 FASTA files; '{}' is filled with
# the file stem ('train_set' or 'benchmark_set_sp5' in _split_generators).
_FAST_URL_PATTERN = 'https://services.healthtech.dtu.dk/services/SignalP-6.0/public_data/{}.fasta'

# Seed for the deterministic shuffle in SignalPeptide._generate_examples.
_SHUFFLE_SEED = 101010101


class SignalPeptide(tfds.core.GeneratorBasedBuilder):
    """TFDS builder for the SignalP 6.0 FASTA data.

    The train split is built from ``train_set.fasta`` and the validation split
    from ``benchmark_set_sp5.fasta``; each example is one 3-line FASTA record.
    """

    VERSION = tfds.core.Version('1.0.0')

    BUILDER_CONFIGS = [
        tfds.core.BuilderConfig(
            name='sp6',
            description='TODO',
            version=VERSION,
        ),
    ]

    def _info(self):
        """Return the DatasetInfo; there is one feature per SpExample field."""
        return tfds.core.DatasetInfo(
            builder=self,
            description="TODO",
            features=tfds.features.FeaturesDict({
                "uniprot_ac": tfds.features.Text(),
                "kingdom": tfds.features.Text(),
                "partition": tf.int64,
                "label": tfds.features.ClassLabel(names=SP_LABELS),
                "aa_sequence": tfds.features.Text(),
                "annotation_sequence": tfds.features.Text(),
            }),
            supervised_keys=None,
            homepage=_SP_URL,
            citation=_SP_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Download both FASTA files and map them to the train/validation splits."""
        filepaths = dl_manager.download({
            'train': _FAST_URL_PATTERN.format('train_set'),
            'validation': _FAST_URL_PATTERN.format('benchmark_set_sp5'),
        })
        return [
            tfds.core.SplitGenerator(
                name=tfds.Split.TRAIN,
                gen_kwargs={
                    "filepath": filepaths['train']
                }),
            tfds.core.SplitGenerator(
                name=tfds.Split.VALIDATION,
                gen_kwargs={
                    "filepath": filepaths['validation']
                }),
        ]

    def _generate_examples(self, filepath):
        """Yield ``(uniprot_ac, feature_dict)`` for every record in `filepath`.

        Raises:
            ValueError: if the file's non-blank lines do not form whole 3-line
                records, or if any record is malformed (see parse_3_line_fasta).
        """
        with open(filepath, 'rt') as f:
            # Blank lines (e.g. trailing newlines at EOF) carry no data.
            lines = [line for line in f if line.strip()]

        # Check the record structure up front so that a malformed file fails
        # loudly instead of silently truncating the split; per-record parse
        # errors from parse_3_line_fasta also propagate.
        if len(lines) % 3:
            raise ValueError(
                f'{filepath}: expected 3 lines per record, but found '
                f'{len(lines)} non-blank lines')
        examples = [
            parse_3_line_fasta(lines[start:start + 3])
            for start in range(0, len(lines), 3)
        ]

        # The entries in the fasta files are not in random order,
        # so we deterministically shuffle them.
        # NOTE(review): TFDS may also reorder examples by key when writing;
        # confirm whether this shuffle still affects the on-disk order.
        random.Random(_SHUFFLE_SEED).shuffle(examples)

        for ex in examples:
            yield ex.uniprot_ac, ex.to_tfds_dict()
