"""Context for analysis of NPEFF for signal peptide data set."""
import dataclasses
import os
from typing import List

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from transformers import AutoTokenizer, PreTrainedTokenizer

from em import datasets as em_datasets
from em.datasets.protein import signal_peptide
from em.fishers import per_example
from em.tools.nmf import nmf_common

# Label names re-exported from the signal peptide dataset module;
# SpExample.get_label_as_str indexes this sequence with SpExample.label.
SP_LABELS = signal_peptide.SP_LABELS

##########################################################################


def _decode_aa_sequence(tokenizer: PreTrainedTokenizer, token_ids: np.ndarray) -> str:
    s = tokenizer.decode(token_ids)
    s = s.replace(tokenizer.cls_token, '')
    s = s.replace(tokenizer.sep_token, '')
    s = s.replace(tokenizer.pad_token, '')
    s = s.replace(' ', '')
    return s

##########################################################################


@dataclasses.dataclass
class SpExample:
    """A single signal peptide example together with the model's output."""

    # Position of the example within the analyzed dataset split.
    index: int

    # Metadata read from the dataset example.
    uniprot_ac: str
    kingdom: str
    partition: int

    # Sequences: the decoded model input and the dataset's annotation string.
    aa_sequence: str
    annotation_sequence: str

    # Ground-truth label (an index into SP_LABELS) and the model's logits.
    label: int
    logits: np.ndarray

    @property
    def binary_label(self) -> int:
        """1 if ``label`` is non-zero, else 0."""
        return int(self.label != 0)

    @property
    def prediction(self) -> int:
        """Index of the largest logit, as a plain Python int."""
        # np.argmax returns a NumPy integer; cast so the value matches the
        # annotation (and binary_label, which is also a plain int).
        return int(np.argmax(self.logits))

    def get_label_as_str(self) -> str:
        """Returns the name of ``label`` from SP_LABELS."""
        return SP_LABELS[self.label]

    def is_task_binarized(self) -> bool:
        """True when there are exactly two logits.

        Presumably this means the model was trained on the binary task that
        matches ``binary_label``; confirm against the training setup.
        """
        return len(self.logits) == 2


##########################################################################


@dataclasses.dataclass
class NpeffContext:
    """Bundles per-example Fishers, their NMF decomposition and dataset examples.

    Assumes the decomposed examples are the first examples in the dataset
    split: row ``i`` of ``nmf.W`` corresponds to example ``i`` of the
    ``signal_peptide/sp6`` split named by ``split``. It also corresponds to
    entry ``i`` of ``pef.input_ids`` and ``pef.predicted_logits``.
    """

    pef: per_example.PerExampleFlatFishers
    nmf: nmf_common.SparseNmfDecomposition
    tokenizer: PreTrainedTokenizer

    # Name of the tfds split the examples are read from.
    split: str

    def __post_init__(self):
        # Rows of W index examples; columns index NMF components.
        self.n_components = self.nmf.W.shape[1]
        self.n_examples = self.nmf.W.shape[0]

        # Fail early with a clear message instead of an IndexError while the
        # examples are being built.
        n_pef_examples = len(self.pef.input_ids)
        if n_pef_examples < self.n_examples:
            raise ValueError(
                f'The per-example Fishers contain {n_pef_examples} examples, '
                f'but the NMF decomposition has {self.n_examples} rows.')

        self.examples = self._create_examples()

    def _create_examples(self) -> List[SpExample]:
        """Builds one SpExample per NMF row from the first examples of the split.

        Raises:
            ValueError: if the split has fewer than ``self.n_examples`` examples.
        """
        ds = tfds.load("signal_peptide/sp6", split=self.split)
        ds = ds.take(self.n_examples)

        examples = []
        for i, ex in enumerate(ds.as_numpy_iterator()):
            # Use this to get exactly what the model sees. I think the only real
            # preprocessing is replacing ambiguous or non-canonical AAs with the
            # unknown AA character code.
            aa_sequence = _decode_aa_sequence(self.tokenizer, self.pef.input_ids[i])
            examples.append(
                SpExample(
                    index=i,
                    uniprot_ac=tf.compat.as_str(ex['uniprot_ac']),
                    kingdom=tf.compat.as_str(ex['kingdom']),
                    # as_numpy_iterator yields NumPy values; cast to plain ints
                    # to match SpExample's declared field types.
                    partition=int(ex['partition']),
                    aa_sequence=aa_sequence,
                    annotation_sequence=tf.compat.as_str(ex['annotation_sequence']),
                    label=int(ex['label']),
                    logits=self.pef.predicted_logits[i],
                ))

        # take() silently yields fewer elements on a short split, which would
        # otherwise surface later as an IndexError in get_top_examples.
        if len(examples) < self.n_examples:
            raise ValueError(
                f'Split {self.split!r} has only {len(examples)} examples, but '
                f'the NMF decomposition has {self.n_examples} rows.')

        return examples

    def get_top_example_indices(self, component_index: int, n_examples: int) -> np.ndarray:
        """Returns indices of the examples with the largest coefficients for a component.

        Args:
            component_index: column of ``nmf.W`` to rank examples by.
            n_examples: maximum number of indices to return.

        Returns:
            Example indices ordered by decreasing coefficient. Ties are kept
            in ascending index order.
        """
        coeffs = self.nmf.W[:, component_index]
        # A stable sort makes the order of tied coefficients deterministic.
        return np.argsort(-coeffs, kind='stable')[:n_examples]

    def get_top_examples(self, component_index: int, n_examples: int) -> List[SpExample]:
        """Returns the SpExamples behind ``get_top_example_indices``, in the same order."""
        return [self.examples[i] for i in self.get_top_example_indices(component_index, n_examples)]

    @classmethod
    def load(cls, pef_path: str, nmf_path: str, tokenizer: str, **kwargs) -> 'NpeffContext':
        """Loads the Fishers metadata, NMF and tokenizer and builds a context.

        Args:
            pef_path: path passed to ``PerExampleFlatFishers.load``.
            nmf_path: path passed to ``SparseNmfDecomposition.load``.
            tokenizer: name or path passed to ``AutoTokenizer.from_pretrained``.
            **kwargs: remaining constructor arguments (e.g. ``split``).
        """
        pef = per_example.PerExampleFlatFishers.load(
            pef_path,
            n_examples=None,
            # This leads to the Fishers not being loaded, which ends up being much faster.
            start_fisher_index=0,
            end_fisher_index=0,
        )

        nmf = nmf_common.SparseNmfDecomposition.load(nmf_path)
        nmf.normalize_components_to_unit_norm()

        # TODO: Allow us to specify other subsets of example indices.
        n_nmf_examples = nmf.W.shape[0]
        if pef.input_ids.shape[0] > n_nmf_examples:
            pef = pef.create_for_subset(list(range(n_nmf_examples)))

        hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer)

        return cls(pef=pef, nmf=nmf, tokenizer=hf_tokenizer, **kwargs)
