"""The model classes for protein-bert."""
import dataclasses
from typing import Optional

from proteinbert import conv_and_global_attention_model as og_pb_model

import tensorflow as tf


###############################################################################

# Number of annotation slots: passed as `n_annotations` when the upstream model
# is built in `_create_og_body`, and used as the width of the all-zero
# annotations tensor fed to the body by the classification model.
# NOTE(review): presumably this must match the annotation count of the
# pretrained checkpoint being loaded -- confirm before changing.
DEFAULT_N_ANNOTATIONS = 8943

###############################################################################


@dataclasses.dataclass
class BodyConfig:
    """Hyperparameters of the pretrained protein-bert body.

    Each field is forwarded, under the same keyword name, to the upstream
    ``create_model`` call made by ``_create_og_body``.
    """

    vocab_size: int
    d_hidden_seq: int
    d_hidden_global: int
    n_blocks: int
    n_heads: int
    d_key: int
    conv_kernel_size: int
    wide_conv_dilation_rate: int
    activation: str

    @classmethod
    def from_dict(cls, dct: dict) -> "BodyConfig":
        """Build a config from a mapping of field name to value.

        The mapping is expanded into keyword arguments of the constructor, so
        missing or unexpected keys raise ``TypeError``.
        """
        instance = cls(**dct)
        return instance

    def to_dict(self) -> dict:
        """Return the fields as a plain ``dict`` keyed by field name."""
        field_values = dataclasses.asdict(self)
        return field_values


@dataclasses.dataclass
class ClassifierConfig:
    """Settings consumed by the classification head."""

    # Classifier-specific: number of output units of the head's dense layer.
    n_classes: int
    # Misc/general: rate handed to the head's dropout layer.
    dropout_rate: float

    @classmethod
    def from_dict(cls, dct: dict) -> "ClassifierConfig":
        """Build a config from a mapping of field name to value.

        The mapping is expanded into keyword arguments of the constructor, so
        missing or unexpected keys raise ``TypeError``.
        """
        instance = cls(**dct)
        return instance

    def to_dict(self) -> dict:
        """Return the fields as a plain ``dict`` keyed by field name."""
        field_values = dataclasses.asdict(self)
        return field_values


@dataclasses.dataclass
class ProteinBertConfig:
    """Top-level configuration: body settings plus an optional classifier."""

    # TODO: Make this more like HuggingFace's.

    # Hyperparameters of the pretrained body (always required).
    body: BodyConfig

    # Classification-head settings; stays ``None`` unless explicitly supplied.
    classifier: Optional[ClassifierConfig] = None

    # "Run-time" settings live below this line: they do not need to be saved
    # to a checkpoint and are allowed to change when a model is reloaded.
    sequence_length: Optional[int] = None

###############################################################################


def _create_og_body(config: ProteinBertConfig) -> tf.keras.Model:
    """Instantiate the upstream protein-bert model described by ``config``.

    ``config.sequence_length`` is passed as ``seq_len``, the module-level
    ``DEFAULT_N_ANNOTATIONS`` as ``n_annotations``, and every ``config.body``
    field under its own name.

    Notes carried over from the upstream implementation:
      * input:  keras.layers.Input(shape=(seq_len,), dtype=np.int32,
                name='input-seq')
      * output: pretraining_output_seq_layer,
                pretraining_output_annotations_layer

    Args:
        config: Full model configuration.

    Returns:
        The model returned by ``og_pb_model.create_model``.
    """
    seq_len = config.sequence_length
    body_cfg = config.body

    # Architecture hyperparameters, keyed by the upstream keyword names.
    architecture_kwargs = {
        "vocab_size": body_cfg.vocab_size,
        "d_hidden_seq": body_cfg.d_hidden_seq,
        "d_hidden_global": body_cfg.d_hidden_global,
        "n_blocks": body_cfg.n_blocks,
        "n_heads": body_cfg.n_heads,
        "d_key": body_cfg.d_key,
        "conv_kernel_size": body_cfg.conv_kernel_size,
        "wide_conv_dilation_rate": body_cfg.wide_conv_dilation_rate,
        "activation": body_cfg.activation,
    }

    return og_pb_model.create_model(
        seq_len=seq_len,
        n_annotations=DEFAULT_N_ANNOTATIONS,
        **architecture_kwargs,
    )


class ClassifierHead(tf.keras.Model):
    """Classification head: dropout followed by a dense projection.

    Args:
        config: Model configuration. ``config.classifier`` must be set; it
            supplies ``dropout_rate`` for the dropout layer and ``n_classes``
            for the dense layer.
        **kwargs: Forwarded to ``tf.keras.Model.__init__``.

    Raises:
        ValueError: If ``config.classifier`` is ``None``.
    """

    def __init__(self, config: ProteinBertConfig, **kwargs):
        super().__init__(**kwargs)

        # `classifier` is Optional on the config and defaults to None; fail
        # with an actionable message instead of an AttributeError on NoneType.
        classifier_config = config.classifier
        if classifier_config is None:
            raise ValueError(
                "ClassifierHead requires `config.classifier` to be set, "
                "but it is None."
            )

        self._config = config
        self.dropout = tf.keras.layers.Dropout(classifier_config.dropout_rate)
        # activation=None: the dense layer applies no activation to its output.
        self.dense = tf.keras.layers.Dense(classifier_config.n_classes, activation=None)

    def call(self, x, training=None):
        """Apply dropout, then the dense projection.

        Args:
            x: Input tensor.
            training: Training-mode flag passed through to both layers.

        Returns:
            The output of the dense layer.
        """
        x = self.dropout(x, training=training)
        return self.dense(x, training=training)


class ProteinBertForSequenceClassification(tf.keras.Model):
    """Protein-BERT for sequence classification.

    Wraps the upstream protein-bert model, truncates it just before its
    'output-annotations' layer, and feeds that layer's input to a
    ``ClassifierHead``. Calling the model returns the head's logits.

    Args:
        config: Model configuration. ``config.classifier`` must be set.
        **kwargs: Forwarded to ``tf.keras.Model.__init__``.

    Raises:
        ValueError: If ``config.classifier`` is ``None``.
    """

    def __init__(self, config: ProteinBertConfig, **kwargs):
        super().__init__(**kwargs)

        # `classifier` is Optional on the config and defaults to None. Check
        # it up front so a misconfiguration fails before the upstream body is
        # built, and with a clearer message than an AttributeError on NoneType.
        if config.classifier is None:
            raise ValueError(
                "ProteinBertForSequenceClassification requires "
                "`config.classifier` to be set, but it is None."
            )

        self._config = config
        self._n_annotations = DEFAULT_N_ANNOTATIONS

        self._og_body = _create_og_body(config)
        self._body = self._make_body()

        self._classifier = ClassifierHead(config)

    def _make_body(self):
        """Build a model mapping the upstream inputs to the pre-output features.

        Returns:
            A Keras model that shares the upstream model's inputs and outputs
            the tensor that feeds the 'output-annotations' layer.

        Raises:
            ValueError: If the upstream model does not contain exactly one
                layer named 'output-annotations' (raised by the unpacking).
        """
        # The original implementation appears to run a classifier head on the
        # sigmoided output, which does not seem good. This ends up feeding the
        # input of that layer to our classification head, which is more
        # similar to how BERT and everyone else does it.
        # The single-target unpacking enforces that exactly one layer matches.
        output_annotation_layer, = [layer for layer in self._og_body.layers if layer.name == 'output-annotations']
        return tf.keras.models.Model(self._og_body.inputs, outputs=output_annotation_layer.input)

    def _make_dummy_annotations(self, sequences: tf.Tensor) -> tf.Tensor:
        """Create an all-zero float32 annotations tensor matching ``sequences``.

        The result keeps every dimension of ``sequences`` except the last,
        which is replaced by ``self._n_annotations``; e.g. a rank-2 input of
        shape (a, b) yields zeros of shape (a, n_annotations).
        """
        shape = tf.concat([tf.shape(sequences)[:-1], [self._n_annotations]], axis=0)
        return tf.zeros(shape, dtype=tf.float32)

    def call(self, x: tf.Tensor, training=None):
        """Run the body on ``x`` with zeroed annotations and classify the result.

        Args:
            x: Sequence tensor passed as the body's first input.
            training: Training-mode flag passed to the body and the head.

        Returns:
            The logits produced by the classification head.
        """
        # The body takes two inputs; annotations are not available here, so
        # an all-zero tensor is supplied in their place.
        output = self._body([x, self._make_dummy_annotations(x)], training=training)
        logits = self._classifier(output, training=training)
        return logits
