# A BERT model that
# - has an embedding projector when embedding_size != hidden_size, like ELECTRA
# - uses a single linear projection in attention to generate query, key and value at once, for speed
# - is able to choose rotary position embedding
#   (NOTE(review): the Embeddings class in this file only uses learned absolute position embeddings — confirm)

from copy import deepcopy
import math
import torch
from torch import nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from _modeling.configuration_tsp import TSPConfig


class TSPPreTrainedModel(PreTrainedModel):
    """Common base for all TSP models: config class, checkpoint prefix and weight init."""

    config_class = TSPConfig
    base_model_prefix = "backbone"

    def _init_weights(self, module):
        """Initialize the weights of a single submodule in place (BERT-style).

        Linear/Embedding weights ~ N(0, 0.02); Linear biases and LayerNorm biases are
        zeroed; LayerNorm weights are set to 1; the padding row of an Embedding is zeroed.
        """
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        # Slightly different from the TF version which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.Linear):
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()


# ====================================
# Pretraining Model
# ====================================


class TSPModelForPreTraining(TSPPreTrainedModel):
    """Pretraining model combining masked language modeling (MLM), optional ELECTRA
    replaced token discrimination (RTD), and text structure prediction (TSP).

    With ``config.use_electra`` a smaller generator backbone (hidden size,
    intermediate size and number of heads divided by
    ``config.electra_generator_size_divisor``) is trained with MLM. Tokens sampled
    from its predictions are fed to ``self.backbone``, which is trained with RTD.
    Both backbones share one embedding module. Without ELECTRA, ``self.backbone`` is
    trained with MLM directly. The TSP loss is computed on the hidden states of the
    last backbone that ran.
    """

    def __init__(self, config, num_classes=None):
        """Build the backbone(s) and heads.

        Args:
            config: ``TSPConfig``-like configuration.
            num_classes: number of TSP relation classes. ``None`` (default) means
                ``len(self.tsp_classes)``, i.e. the 6 relations labelled below.
        """
        super().__init__(config)
        # Label id of each sentence-pair relation (used by text_structure_prediction).
        # Defined first so the TSP head can default its output size to it.
        self.tsp_classes = {
            "same_doc_and_reverse": 0,
            "same_para_and_reverse": 1,
            "neighbor_and_reverse": 2,
            "neighbor_and_forward": 3,
            "same_para_and_forward": 4,
            "same_doc_and_forward": 5,
        }
        if num_classes is None:
            num_classes = len(self.tsp_classes)

        self.backbone = TSPModel(config)
        if config.use_electra:
            generator_config = deepcopy(config)
            divisor = config.electra_generator_size_divisor
            generator_config.hidden_size //= divisor
            generator_config.intermediate_size //= divisor
            generator_config.num_attention_heads //= divisor
            self.generator_backbone = TSPModel(generator_config)
            # MLM output layer is tied to the (shared) word embeddings
            self.mlm_head = MaskedLMHead(
                generator_config,
                word_embeddings=self.mlm_backbone.embeddings.word_embeddings,
            )
            # Generator and discriminator share one embedding module
            self.rtd_backbone.embeddings = self.mlm_backbone.embeddings
            self.rtd_head = ReplacedTokenDiscriminationHead(config)
        else:
            self.mlm_head = MaskedLMHead(
                config, word_embeddings=self.mlm_backbone.embeddings.word_embeddings
            )
        self.tsp_head = TextStructurePredictionHead(config, num_classes=num_classes)
        self.apply(self._init_weights)

        # Other attributes
        self.pad_token_id = config.pad_token_id
        self.sep_token_id = config.sep_token_id

    @property
    def mlm_backbone(self):
        """Backbone trained with MLM: the generator under ELECTRA, else the main backbone."""
        if self.config.use_electra:
            return self.generator_backbone
        else:
            return self.backbone

    @property
    def rtd_backbone(self):
        """Backbone trained with RTD (the main backbone); only valid under ELECTRA."""
        assert self.config.use_electra
        return self.backbone

    def forward(
        self,
        masked_ids,  # <int>(B,L)
        segment_ids,  # <int>(B,L)
        mlm_labels,  # <int>(B,L), [original token ids / -100] at mlm [selected / non-selected] positions
        sentence_marks,  # <int>(B, L), index of sentence the token belongs to against unshuffled sentences
        paragraph_ids,  # <int>(B,S), id of paragraph the sentence belongs to, in order of unshuffled sentences
        permutation=None,  # <int>(B,L), token permutation to shuffle sentences
    ):
        """Return the total pretraining loss (a scalar tensor).

        loss = mlm_loss + 50 * rtd_loss + 2 * tsp_loss under ELECTRA, otherwise
        loss = mlm_loss + tsp_loss. ``permutation`` is only applied on the ELECTRA
        path (to the RTD inputs); without ELECTRA the inputs are used as given.
        """
        loss = 0

        (
            mlm_loss,
            mlm_logits,  # (#mlm_selected, vocab size)
            mlm_original_ids,  # <int>(#mlm_selected)
            hidden_states,  # (B,L,D)
            mlm_selected,  # <bool>(B,L)
        ) = self.mlm_phase(
            masked_ids=masked_ids, segment_ids=segment_ids, mlm_labels=mlm_labels
        )
        loss += mlm_loss

        if self.config.use_electra:
            replaced_ids, rtd_labels = self.sampling_phase(
                mlm_logits=mlm_logits,
                masked_ids=masked_ids,
                mlm_original_ids=mlm_original_ids,
                mlm_selected=mlm_selected,
            )

            if permutation is not None:  # Apply sentence shuffling
                replaced_ids = replaced_ids.gather(1, permutation)
                rtd_labels = rtd_labels.gather(1, permutation)
                # Recompute segments after shuffling: 0 up to and including the first SEP, 1 after it
                is_sep = replaced_ids == self.sep_token_id
                segment_ids = (is_sep.cumsum(dim=1) - is_sep.long()).bool().long()

            rtd_loss, rtd_logits, hidden_states = self.rtd_phase(
                replaced_ids=replaced_ids,
                segment_ids=segment_ids,
                rtd_labels=rtd_labels,
            )
            loss += rtd_loss * 50

        tsp_loss, tsp_logits, tsp_labels = self.text_structure_prediction(
            sentence_embeddings=self.get_sentence_embeddings(
                hidden_states=hidden_states,
                sentence_marks=sentence_marks,
                max_num_sentences=paragraph_ids.shape[1],
            ),
            paragraph_ids=paragraph_ids,
        )
        loss += tsp_loss * (2 if self.config.use_electra else 1)

        return loss

    def mlm_phase(
        self,
        masked_ids,  # <int>(B,L)
        segment_ids,  # <int>(B,L)
        mlm_labels,  # <int>(B,L), [original token ids / -100] at mlm [selected / non-selected] positions
    ):
        """Run the MLM backbone and head; return (loss, logits, original ids, hidden states, selected mask)."""
        # Backbone Forward
        hidden_states = self.mlm_backbone(
            input_ids=masked_ids,
            attention_mask=(masked_ids != self.pad_token_id).long(),
            token_type_ids=segment_ids,
        )  # (B,L,D)

        # MLM Loss (only computed at selected positions)
        mlm_selected = mlm_labels != -100  # <bool>(B,L)
        mlm_logits = self.mlm_head(
            hidden_states, is_selected=mlm_selected
        )  # (#mlm selected, vocab size)
        mlm_original_ids = mlm_labels[mlm_selected]  # <int>(#mlm_selected)
        mlm_loss = F.cross_entropy(mlm_logits, mlm_original_ids)

        return mlm_loss, mlm_logits, mlm_original_ids, hidden_states, mlm_selected

    @torch.no_grad()  # Note gradient flow stops here naturally, adding no_grad here is just to strengthen it and get a little bit faster.
    def sampling_phase(
        self,
        mlm_logits,  # (#mlm_selected, V)
        masked_ids,  # <int>(B,L)
        mlm_original_ids,  # <int>(#mlm_selected)
        mlm_selected,  # <bool>(B,L)
    ):
        """Sample replacement tokens (Gumbel-max) and build RTD labels.

        Returns:
            replaced_ids: <int>(B,L), ``masked_ids`` with sampled ids at selected positions.
            rtd_labels: <int>(B,L), 1 where the sampled id differs from the original, else 0.
        """
        # RTD input ids: argmax over logits + Gumbel noise == sampling from softmax(logits)
        sampling_logits = self.add_gumbel_noise(mlm_logits)
        sampled_ids = sampling_logits.argmax(dim=1)  # <int>(#mlm selected)
        replaced_ids = masked_ids.masked_scatter(mlm_selected, sampled_ids)
        # RTD labels
        is_replaced = sampled_ids != mlm_original_ids
        rtd_labels = torch.zeros_like(replaced_ids)  # <int>(B,L)
        rtd_labels[mlm_selected] = is_replaced.long()

        return replaced_ids, rtd_labels

    def rtd_phase(
        self,
        replaced_ids,  # <int>(B,L)
        segment_ids,  # <int>(B,L)
        rtd_labels,  # <int>(B,L), [0 / 1] for [non-replaced / replaced] tokens
    ):
        """Run the discriminator; return (loss, logits at non-pad positions, hidden states)."""
        # Backbone Forward
        non_pad = replaced_ids != self.pad_token_id
        hidden_states = self.rtd_backbone(
            input_ids=replaced_ids,
            attention_mask=non_pad.long(),
            token_type_ids=segment_ids,
        )  # (B,L,D)

        # RTD Loss. The head already returns (B,L); squeezing again would drop L when L == 1.
        logits = self.rtd_head(hidden_states)  # (B,L)
        logits = logits[non_pad]  # (#non pad)
        rtd_labels = rtd_labels[non_pad]  # <int>(#non pad)
        rtd_loss = F.binary_cross_entropy_with_logits(logits, rtd_labels.float())

        return rtd_loss, logits, hidden_states

    def get_sentence_embeddings(
        self,
        hidden_states,  # (B,L,D)
        sentence_marks,  # <int>(B, L), index of sentence the token belongs to against unshuffled sentences
        max_num_sentences: int,
    ):
        """Mean-pool token states into per-sentence embeddings of shape (B,S,D).

        Tokens whose mark is outside [0, S) (e.g. -1) are ignored, and sentences with
        no tokens pool to zero vectors. Dropout (``config.dropout_prob``) is applied
        before and after pooling. Also called unbound from
        ``TSPModelForTextPair.forward``, so it only reads ``self.config`` and
        ``self.training``.
        """
        B, L, S = *sentence_marks.shape, max_num_sentences
        dtype, device = hidden_states.dtype, hidden_states.device

        hidden_states = F.dropout(
            hidden_states, self.config.dropout_prob, self.training
        )

        mask = torch.arange(S, device=device)
        mask = mask.view(1, -1, 1).expand(B, -1, L)  # <int>(B,#sent,L)
        mask = mask == sentence_marks.view(B, 1, L)  # <bool>(B,#sent,L)
        sentence_lengths = mask.sum(dim=-1, keepdim=True)  # <int>(B,#sent,1)
        sentence_lengths = sentence_lengths.clip(min=1)  # avoid zero division
        sent_embeds = torch.bmm(mask.to(dtype=dtype), hidden_states)  # (B,#sent,D)
        sent_embeds /= sentence_lengths

        sent_embeds = F.dropout(sent_embeds, self.config.dropout_prob, self.training)

        return sent_embeds  # (B,#sent,D)

    def text_structure_prediction(
        self,
        sentence_embeddings,  # (B,S,D)
        paragraph_ids,  # <int>(B, S), id of paragraph which sentence belongs to in order of unshuffled sentences, -1 for padding
    ):
        """Label every sentence pair (i, j) and compute the TSP loss.

        A pair is "forward" if j > i and "reverse" if j < i (same_doc level). The
        label is refined to same_para when the paragraph ids match, and to
        neighbor when |i - j| == 1; neighbor takes precedence over same_para.
        The diagonal (i == j) and pairs involving a padded sentence (id -1) are
        labelled -1 and ignored.
        """
        B, S, device = *paragraph_ids.shape, paragraph_ids.device
        para_ids = paragraph_ids  # <int>(B, S)
        padding = para_ids == -1  # <bool>(B, S)

        # Get relations of sentence pairs at different level of hierarchy
        positions = torch.arange(S, device=device).view(1, S)
        diff = positions.view(-1, 1, S) - positions.view(-1, S, 1)  # <int>(1,S,S), diff[i,j] = j - i
        dist = diff.abs()  # <int>(1,S,S)
        same_para = para_ids.view(B, 1, S) == para_ids.view(B, S, 1)  # <bool>(B,S,S)
        neighbor = dist == 1
        reverse, forward = (diff < 0).expand(B, S, S), (diff > 0).expand(B, S, S)

        # Labeling (later assignments override earlier, more general ones)
        labels = torch.full_like(same_para, -1, dtype=torch.long)
        labels[forward] = self.tsp_classes["same_doc_and_forward"]
        labels[reverse] = self.tsp_classes["same_doc_and_reverse"]
        labels[same_para & forward] = self.tsp_classes["same_para_and_forward"]
        labels[same_para & reverse] = self.tsp_classes["same_para_and_reverse"]
        labels[neighbor & forward] = self.tsp_classes["neighbor_and_forward"]
        labels[neighbor & reverse] = self.tsp_classes["neighbor_and_reverse"]

        # Ignoring pairs that involve padded sentences
        ignoring = padding.view(B, 1, S) | padding.view(B, S, 1)
        labels.masked_fill_(ignoring, -1)

        return self._text_structure_prediction(
            sentence_embeddings=sentence_embeddings, tsp_labels=labels,
        )

    def _text_structure_prediction(
        self,
        sentence_embeddings,  # (B,S,D)
        tsp_labels,  # <int>(B,S,S), label < 0 at ignored/padding position
    ):
        """Classify concatenated (sentence i, sentence j) embeddings for labelled pairs only."""
        B, S, D, sent_embeds = *sentence_embeddings.shape, sentence_embeddings

        learning = tsp_labels >= 0  # <bool>(B,S,S)
        a_embeds = sent_embeds.view(B, S, 1, D).expand(B, S, S, D)[learning]
        b_embeds = sent_embeds.view(B, 1, S, D).expand(B, S, S, D)[learning]
        sent_pair_embeds = torch.cat([a_embeds, b_embeds], dim=-1)  # (#learn, 2D)
        sent_pair_logits = self.tsp_head(sent_pair_embeds)  # (#learn, #classes)

        tsp_labels = tsp_labels[learning]  # <int>(#learn)
        tsp_loss = F.cross_entropy(sent_pair_logits, tsp_labels)

        return tsp_loss, sent_pair_logits, tsp_labels

    def add_gumbel_noise(self, tensor):
        """Return ``tensor`` plus i.i.d. standard Gumbel(0, 1) noise of the same shape.

        The distribution is cached, but rebuilt whenever ``tensor``'s dtype or device
        differs from the cached one (e.g. after ``model.to(...)``), because
        torch.distributions objects are not moved/cast together with the module
        (see https://github.com/pytorch/pytorch/issues/7795).
        """
        dist = getattr(self, "gumbel_distribution", None)
        if (
            dist is None
            or dist.loc.dtype != tensor.dtype
            or dist.loc.device != tensor.device
        ):
            dist = torch.distributions.gumbel.Gumbel(
                loc=torch.tensor(0.0, device=tensor.device, dtype=tensor.dtype),
                scale=torch.tensor(1.0, device=tensor.device, dtype=tensor.dtype),
            )
            self.gumbel_distribution = dist
        with torch.cuda.amp.autocast(enabled=False):
            # temporarily disable autocast to avoid unneeded forced fp32 sampling
            noise = dist.sample(tensor.shape)
        return tensor + noise


class MaskedLMHead(nn.Module):
    """Map hidden states to vocabulary logits: Linear(D->E) -> GELU -> LayerNorm -> Linear(E->V).

    When ``word_embeddings`` is given, the output projection shares its weight
    matrix (weight tying); the output bias stays separate.
    """

    def __init__(self, config, word_embeddings=None):
        super().__init__()
        hidden_size, embedding_size = config.hidden_size, config.embedding_size
        self.linear = nn.Linear(hidden_size, embedding_size)
        self.norm = nn.LayerNorm(embedding_size)
        self.predictor = nn.Linear(embedding_size, config.vocab_size)
        if word_embeddings is None:
            return
        self.predictor.weight = word_embeddings.weight

    def forward(
        self,
        x,  # (B,L,D)
        is_selected=None,  # <bool>(B,L), True at positions chosen by mlm probability
    ):
        # Only mlm positions contribute to the loss, so restricting the output layer
        # to them avoids computing (B,L,V) logits.
        features = x if is_selected is None else x[is_selected]  # (B,L,D)/(#selected,D)
        features = self.norm(F.gelu(self.linear(features)))  # (B,L,E)/(#selected,E)
        return self.predictor(features)  # (B,L,V)/(#selected,V)


class ReplacedTokenDiscriminationHead(nn.Module):
    """Per-token replaced/original logit: Linear(D->D) -> GELU -> Linear(D->1), last dim squeezed."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.linear = nn.Linear(width, width)
        self.predictor = nn.Linear(width, 1)

    def forward(self, x):  # (B,L,D)
        logits = self.predictor(F.gelu(self.linear(x)))  # (B,L,1)
        return logits.squeeze(-1)  # (B,L)


class TextStructurePredictionHead(nn.Module):
    """Classify a concatenated sentence-pair embedding (...,2D) into ``num_classes`` relations."""

    def __init__(self, config, num_classes):
        super().__init__()
        pair_width = config.hidden_size * 2
        self.linear1 = nn.Linear(pair_width, pair_width)
        self.norm = nn.LayerNorm(pair_width)
        self.linear2 = nn.Linear(pair_width, num_classes)

    def forward(
        self, x,  # (...,2D)
    ):
        hidden = self.norm(F.gelu(self.linear1(x)))  # (...,2D)
        return self.linear2(hidden)  # (...,C)


# ====================================
# Finetuning Model
# ====================================


class TSPModelForFinetuning(TSPPreTrainedModel):
    """Generic finetuning model: a TSP backbone followed by ``head_cls(config, **kwargs)``."""

    def __init__(self, config, head_cls, **kwargs):
        super().__init__(config)
        self.backbone = TSPModel(config)
        self.head = head_cls(config, **kwargs)
        self.apply(self._init_weights)

    def forward(
        self,
        input_ids,  # <int>(B,L)
        attention_mask,  # <int>(B,L), 1 / 0 for tokens that are attended / not attended
        token_type_ids,  # <int>(B,L), 0 / 1 corresponds to a segment A / B token
        *args,
        **kwargs,
    ):
        """Encode the inputs and pass the (B,L,D) states plus any extra args to the head."""
        encoded = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )  # (B,L,D)
        return self.head(encoded, *args, **kwargs)


class TSPModelForTokenClassification(TSPModelForFinetuning):
    """TSP backbone with a per-token classification head producing (B,L,num_classes) logits."""

    def __init__(self, config, num_classes):
        head_kwargs = dict(num_classes=num_classes)
        super().__init__(config, head_cls=TokenClassificationHead, **head_kwargs)


class TokenClassificationHead(nn.Module):
    """Dropout followed by a Linear(D -> num_classes) applied to every token."""

    def __init__(self, config, num_classes):
        super().__init__()
        self.dropout = nn.Dropout(config.dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_classes)

    def forward(self, x):  # (B,L,D)
        return self.classifier(self.dropout(x))  # (B,L,C)


class TSPModelForSpanComparison(TSPModelForFinetuning):
    """TSP backbone with a head that classifies from the CLS embedding plus pooled span embeddings."""

    def __init__(self, config, num_classes):
        head_kwargs = dict(num_classes=num_classes)
        super().__init__(config, head_cls=SpanComparisonHead, **head_kwargs)


class SpanComparisonHead(nn.Module):
    """Classify an example from its CLS embedding plus mean-pooled span embeddings.

    Args:
        config: needs ``dropout_prob`` and ``hidden_size``.
        num_classes: number of output classes.
        num_spans: number of spans per example; must equal ``span_ranges.shape[1]``.
    """

    def __init__(self, config, num_classes, num_spans=2):
        super().__init__()
        self.dropout = nn.Dropout(config.dropout_prob)
        self.classifier = nn.Linear((1 + num_spans) * config.hidden_size, num_classes)

    def forward(
        self,
        x,  # (B,L,D)
        span_ranges,  # <int>(B, #spans, 2), start and end (not included) token indexes for spans
    ):
        """Return (B, num_classes) logits.

        Empty (or inverted) spans pool to a zero vector.
        """
        # span embed. take mean pooling over tokens associated with the span
        B, L, D, device = *x.shape, x.device
        span_start, span_end = span_ranges[:, :, :1], span_ranges[:, :, 1:]  # (B,#spans,1)
        # Clip to 1 so empty/padded spans give 0 instead of NaN (0/0) in forward and backward
        span_length = (span_end - span_start).clip(min=1)  # <int>(B, #spans, 1)
        member_mask = torch.arange(L, device=device)
        member_mask = member_mask.view(1, 1, L).expand(B, 1, L)  # <int>(B,1,L)
        member_mask = (span_start <= member_mask) & (member_mask < span_end)
        member_mask = member_mask.to(dtype=x.dtype)  # (B, #spans,L), 1 or 0
        span_embeds = torch.bmm(member_mask, x)  # (B,#spans,D)
        span_embeds /= span_length  # (B,#spans,D)

        # feed to classifier
        cls_embed = x[:, :1, :]  # (B,1,D)
        pair_embed = torch.cat([cls_embed, span_embeds], dim=1)  # (B, 1+#spans, D)
        pair_embed = self.dropout(pair_embed.view(B, -1))  # (B, (1+#spans)*D)
        logits = self.classifier(pair_embed)  # (B,C)
        return logits


class TSPModelForReCoRD(TSPModelForFinetuning):
    """TSP backbone with a head scoring each candidate entity against the masked query token."""

    def __init__(self, config, num_classes=1):
        head_kwargs = dict(num_classes=num_classes)
        super().__init__(config, head_cls=ReCoRDHead, **head_kwargs)


class ReCoRDHead(nn.Module):
    """Score each candidate entity by pairing its mean-pooled embedding with the query token.

    Args:
        config: needs ``hidden_size`` and ``dropout_prob``.
        num_classes: must be 1 (one score per entity).
    """

    def __init__(self, config, num_classes=1):
        assert num_classes == 1
        super().__init__()
        self.proj = nn.Linear(2 * config.hidden_size, 2 * config.hidden_size)
        self.dropout = nn.Dropout(config.dropout_prob)
        self.classifier = nn.Linear(2 * config.hidden_size, 1)

    def forward(
        self,
        x,  # (B,L,D)
        masked_idx,  # <int>(B), token index of @placeholder (replaced by [mask])
        entity_ranges,  # <int>(B, #entities, 2), start and end (not included) token indices of entities
    ):
        """Return (B, #entities, 1) scores.

        Empty (e.g. padded) entity ranges pool to a zero vector.
        """
        B, L, D, device = *x.shape, x.device

        # Entity embeddings
        positions = torch.arange(L, device=device).view(1, 1, L).expand(B, 1, L)
        entity_starts, entity_ends = entity_ranges[:, :, :1], entity_ranges[:, :, 1:]
        entity_mask = entity_starts <= positions  # <bool>(B,#entities,L)
        entity_mask &= positions < entity_ends  # <bool>(B,#entities,L)
        entity_embeds = torch.bmm(entity_mask.to(x), x)  # (B,#entities,D)
        # Clip to 1 so padded/empty entities give 0 instead of NaN (0/0); otherwise the NaN
        # would also reach the gradients even if these entities are masked out of the loss.
        entity_length = (entity_ends - entity_starts).clip(min=1)  # <int>(B, #entities, 1)
        entity_embeds /= entity_length  # (B,#entities,D)

        # Query embedding
        query_embed = x[torch.arange(B, device=device), masked_idx, :]  # (B,D)
        query_embeds = query_embed.view(B, 1, D).expand_as(
            entity_embeds
        )  # (B,#entities,D)

        # Classify
        pair_embeds = torch.cat(
            [query_embeds, entity_embeds], dim=2
        )  # (B,#entities,2D)
        pair_embeds = self.proj(pair_embeds)
        pair_embeds = F.gelu(pair_embeds)
        pair_embeds = self.dropout(pair_embeds)  # (B,#entities,2D)
        return self.classifier(pair_embeds)  # (B,#entities,1)


class TSPModelForTextPair(TSPPreTrainedModel):
    """Classify a pair of text units with the TSP pair head.

    The pair is taken from one of:
      * two spans (``span_ranges``),
      * the masked query token and each candidate entity (``masked_idx`` + ``entity_ranges``),
      * otherwise the two segments, mean-pooled via ``token_type_ids``/``attention_mask``.
    """

    def __init__(self, config, num_classes):
        super().__init__(config)
        self.backbone = TSPModel(config)
        self.tsp_head = TSPTextPairHead(config, num_classes=num_classes)
        self.apply(self._init_weights)

    def forward(
        self,
        input_ids,  # <int>(B,L)
        attention_mask,  # <int>(B,L), 1 / 0 for tokens that are attended/ not attended
        token_type_ids,  # <int>(B,L), 0 / 1 corresponds to a segment A / B token
        # For Span comparision task
        span_ranges=None,  # <int>(B, #spans, 2), start and end (not included) token indexes for spans
        # For ReCoRD
        masked_idx=None,  # <int>(B), token index of @placeholder (replaced by [mask])
        entity_ranges=None,  # <int>(B, #entities, 2), start and end (not included) token indices of entities
    ):
        """Return logits: (B, #classes), or (B, #entities, #classes) for the ReCoRD path."""
        hidden_states = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )  # (B,L,D)
        B, L, D = hidden_states.shape
        if span_ranges is not None:
            pair_embeds = self.get_span_embeddings(hidden_states, span_ranges)
            pair_embeds = pair_embeds.view(B, -1)  # (B, 2D), assumes exactly 2 spans
        elif entity_ranges is not None:
            entity_embeds = self.get_span_embeddings(hidden_states, entity_ranges)
            query_embed = self.get_token_embeddings(hidden_states, masked_idx)  # (B,D)
            query_embeds = query_embed.view(B, 1, D).expand_as(
                entity_embeds
            )  # (B,#entities,D)
            pair_embeds = torch.cat(
                [query_embeds, entity_embeds], dim=2
            )  # (B,#entities,2D)
        else:
            # <int>(B,L), 0 / 1 / -1 for first sentence / second sentence / padding
            sentence_marks = token_type_ids + (attention_mask - 1)
            pair_embeds = TSPModelForPreTraining.get_sentence_embeddings(
                self,
                hidden_states=hidden_states,
                sentence_marks=sentence_marks,
                max_num_sentences=2,
            )  # (B,2,D)
            pair_embeds = pair_embeds.view(B, -1)  # (B, 2D)
        return self.tsp_head(pair_embeds)  # (B, #classes)

    def get_span_embeddings(self, x, span_ranges):
        """Mean-pool ``x`` (B,L,D) over each [start, end) range; dropout before and after.

        Empty (e.g. padded) ranges pool to a zero vector.
        """
        x = F.dropout(x, p=self.config.dropout_prob, training=self.training)
        B, L, D, device = *x.shape, x.device
        span_start, span_end = span_ranges[:, :, :1], span_ranges[:, :, 1:]
        # Clip to 1 so empty ranges give 0 instead of NaN (0/0) in forward and backward
        span_length = (span_end - span_start).clip(min=1)  # <int>(B, #spans, 1)
        positions = torch.arange(L, device=device).view(1, 1, L).expand(B, 1, L)
        member_mask = (span_start <= positions) & (positions < span_end)
        span_embeds = torch.bmm(member_mask.to(x), x)  # (B,#spans,D)
        span_embeds /= span_length  # (B,#spans,D)
        span_embeds = F.dropout(
            span_embeds, p=self.config.dropout_prob, training=self.training
        )
        return span_embeds  # (B,#spans,D)

    def get_token_embeddings(self, x, token_idxs):
        """Pick one token embedding per example, (B,D), with dropout."""
        B, L, D, device = *x.shape, x.device
        token_embed = x[torch.arange(B, device=device), token_idxs, :]  # (B,D)
        token_embed = F.dropout(
            token_embed, p=self.config.dropout_prob, training=self.training
        )
        return token_embed


class TSPTextPairHead(nn.Module):
    """Classify a concatenated pair embedding (...,2D): Linear -> GELU -> LayerNorm -> Linear."""

    def __init__(self, config, num_classes):
        super().__init__()
        pair_width = config.hidden_size * 2
        self.linear1 = nn.Linear(pair_width, pair_width)
        self.norm = nn.LayerNorm(pair_width)
        self.classifier = nn.Linear(pair_width, num_classes)

    def forward(
        self, x,  # (...,2D)
    ):
        hidden = self.norm(F.gelu(self.linear1(x)))  # (...,2D)
        return self.classifier(hidden)  # (...,C)


class TSPModelForSequenceClassification(TSPModelForFinetuning):
    """TSP backbone with a CLS-token classification head producing (B,num_classes) logits."""

    def __init__(self, config, num_classes):
        head_kwargs = dict(num_classes=num_classes)
        super().__init__(config, head_cls=SequenceClassififcationHead, **head_kwargs)


class SequenceClassififcationHead(nn.Module):
    """Classify a sequence from its first (CLS) token: dropout then Linear(D -> num_classes)."""

    def __init__(self, config, num_classes):
        super().__init__()
        self.dropout = nn.Dropout(config.dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_classes)

    def forward(
        self, x,  # (B,L,D)
    ):
        cls_state = x[:, 0, :]  # (B,D)
        return self.classifier(self.dropout(cls_state))  # (B,C)


class TSPModelForQuestionAnswering(TSPModelForFinetuning):
    """Extractive QA model: TSP backbone followed by ``SquadHead``."""

    def __init__(self, config, beam_size, predict_answerability):
        head_kwargs = dict(beam_size=beam_size, predict_answerability=predict_answerability)
        super().__init__(config, head_cls=SquadHead, **head_kwargs)

    def forward(
        self,
        input_ids,  # <int>(B,L)
        attention_mask,  # <int>(B,L), 1 / 0 for tokens that are attended / not attended
        token_type_ids,  # <int>(B,L), 0 / 1 corresponds to a segment A / B token
        answer_start_position=None,  # train/eval: <int>(B)/None
    ):
        """Return ``(start_logits, end_logits, answerability_logits)`` from the SquadHead."""
        encoded = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )  # (B,L,D)
        head_inputs = dict(
            token_type_ids=token_type_ids,
            answer_start_position=answer_start_position,
        )
        return self.head(encoded, **head_inputs)


class SquadHead(nn.Module):
    """Span-extraction head predicting answer start/end logits and, optionally, answerability.

    End logits are conditioned on the start: in training on the gold start position,
    in evaluation on each of the ``beam_size`` best start positions.
    """

    def __init__(
        self, config, beam_size, predict_answerability,
    ):
        super().__init__()
        self.beam_size = beam_size
        self.predict_answerability = predict_answerability
        hidden = config.hidden_size

        # answer start position predictor
        self.start_predictor = nn.Linear(hidden, 1)

        # answer end position predictor, fed [start hidden ; token hidden]
        self.end_predictor = nn.Sequential(
            nn.Linear(hidden * 2, 512), nn.GELU(), nn.Linear(512, 1),
        )

        # answerability predictor, fed [CLS hidden ; start-prob-weighted hidden]
        self.answerability_predictor = (
            nn.Sequential(nn.Linear(hidden * 2, 512), nn.GELU(), nn.Linear(512, 1),)
            if predict_answerability
            else None
        )

    def forward(
        self,
        hidden_states,  # (B,L,D)
        token_type_ids,  # <int>(B,L), 0 for first segment (question) / padding, 1 for second segment (context)
        answer_start_position=None,  # train/eval: <int>(B)/None
    ):
        """Return ``(start_logits, end_logits, answerability_logits)``.

        start_logits: (B,L); end_logits: (B,L) in training, (B,k,L) in evaluation
        (the k axis is squeezed away when k == 1); answerability_logits: (B,) or None.
        Positions outside the allowed answer range get -inf.
        """
        # Allowed answer positions: tokens with type 1, minus the last such token (the
        # original code identifies it as the final SEP) and everything after it, plus
        # the CLS position (used to express "unanswerable").
        segment_b = token_type_ids  # (B,L)
        seen_b = segment_b.cumsum(dim=1)
        at_or_after_last = seen_b == segment_b.sum(dim=1, keepdim=True)  # (B,L)
        answer_mask = segment_b * ~at_or_after_last
        answer_mask[:, 0] = 1
        answer_mask = answer_mask.bool()

        start_logits, start_hiddens = self._calculate_start(
            hidden_states, answer_mask, answer_start_position
        )  # (B,L), (B,1,D)/(B,k,D)

        end_logits = self._calculate_end_logits(
            hidden_states, start_hiddens, answer_mask,
        )  # (B,L) / (B,k,L)

        if self.answerability_predictor is None:
            answerability_logits = None
        else:
            answerability_logits = self._calculate_answerability_logits(
                hidden_states, start_logits
            )  # (B)

        return start_logits, end_logits, answerability_logits

    def _calculate_start(self, hidden_states, answer_mask, start_positions):
        """Masked start logits (B,L) and the hidden states at the chosen start position(s)."""
        logits = self.start_predictor(hidden_states).squeeze(-1)  # (B,L)
        logits = logits.masked_fill(~answer_mask, -float("inf"))  # (B,L)
        if self.training:
            chosen = start_positions  # (B,), gold starts
        else:
            chosen = logits.topk(k=self.beam_size, dim=-1)[1]  # (B,k)
        gathered = [
            per_example.index_select(dim=0, index=idx)
            for per_example, idx in zip(hidden_states, chosen)
        ]
        return logits, torch.stack(gathered)  # train: (B,1,D) / eval: (B,k,D)

    def _calculate_end_logits(
        self, hidden_states, start_top_hidden_states, answer_mask
    ):
        """End logits for every (start candidate, token) pair, masked outside the answer range."""
        B, L, D = hidden_states.shape
        starts = start_top_hidden_states.view(B, -1, 1, D).expand(
            -1, -1, L, -1
        )  # (B,1/k,L,D)
        tokens = hidden_states.view(B, 1, L, D).expand_as(starts)  # (B,1/k,L,D)
        features = torch.cat([starts, tokens], dim=-1)  # (B,1/k,L,2D)
        logits = self.end_predictor(features).squeeze(-1)  # (B,1/k,L)
        logits = logits.masked_fill(~answer_mask.view(B, 1, L), -float("inf"))
        return logits.squeeze(1)  # train: (B,L) / eval: (B,k,L)

    def _calculate_answerability_logits(self, hidden_states, start_logits):
        """Answerability logit (B,) from the CLS state and the start-probability-weighted state."""
        cls_hidden = hidden_states[:, 0, :]  # (B,D)
        weights = start_logits.softmax(dim=-1).unsqueeze(-1)  # (B,L,1)
        pooled = (weights * hidden_states).sum(dim=1)  # (B,D)
        features = torch.cat([cls_hidden, pooled], dim=-1)  # (B,2D)
        return self.answerability_predictor(features).squeeze(-1)  # (B,)


# ====================================
# Backbone (Transformer Encoder)
# ====================================


class TSPModel(TSPPreTrainedModel):
    """Transformer encoder backbone.

    Embeddings (projected to ``hidden_size`` when ``embedding_size`` differs)
    followed by ``num_hidden_layers`` encoder layers; returns (B,L,D) states.
    """

    def __init__(self, config):
        super().__init__(config)
        self.embeddings = Embeddings(config)
        if config.embedding_size != config.hidden_size:
            self.embeddings_project = nn.Linear(
                config.embedding_size, config.hidden_size
            )
        self.layers = nn.ModuleList(
            [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.apply(self._init_weights)

    def forward(
        self,
        input_ids,  # <int>(B,L)
        attention_mask,  # <int>(B,L), 1 / 0 for tokens that are attended / not attended
        token_type_ids,  # <int>(B,L), 0 / 1 corresponds to a segment A / B token
    ):
        hidden = self.embeddings(
            input_ids=input_ids, token_type_ids=token_type_ids
        )  # (B,L,E)
        if hasattr(self, "embeddings_project"):
            hidden = self.embeddings_project(hidden)  # (B,L,D)

        # Additive mask from transformers' PreTrainedModel helper, (B,1,1,L)
        additive_mask = self.get_extended_attention_mask(
            attention_mask=attention_mask,
            input_shape=input_ids.shape,
            device=input_ids.device,
        )

        for layer in self.layers:
            hidden = layer(hidden, attention_mask=additive_mask)  # (B,L,D)

        return hidden  # (B,L,D)


class Embeddings(nn.Module):
    """Word + token-type + learned absolute position embeddings, then LayerNorm and dropout.

    The position table has ``max(config.max_sequence_length, 512)`` rows.
    """

    def __init__(self, config):
        super().__init__()
        width = config.embedding_size
        self.word_embeddings = nn.Embedding(
            config.vocab_size, width, padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(
            max(config.max_sequence_length, 512), width
        )
        self.token_type_embeddings = nn.Embedding(2, width)
        self.norm = nn.LayerNorm(width)
        self.dropout = nn.Dropout(config.dropout_prob)

    def forward(
        self,
        input_ids,  # <int>(B,L)
        token_type_ids,  # <int>(B,L), 0 / 1 corresponds to a segment A / B token
    ):
        seq_len = input_ids.size(1)
        summed = (
            self.word_embeddings(input_ids)
            + self.token_type_embeddings(token_type_ids)
            + self.position_embeddings.weight[None, :seq_len, :]
        )  # (B,L,E)
        return self.dropout(self.norm(summed))  # (B,L,E)


class EncoderLayer(nn.Module):
    """One Transformer encoder layer: self-attention block, then feed-forward block."""

    def __init__(self, config):
        super().__init__()
        self.self_attn_block = BlockWrapper(config, MultiHeadSelfAttention)
        self.transition_block = BlockWrapper(config, FeedForwardNetwork)

    def forward(
        self,
        x,  # (B,L,D)
        attention_mask,  # additive mask broadcastable to (B,H,L,L): 0 where attended, large negative where not
    ):
        attended = self.self_attn_block(x, attention_mask=attention_mask)  # (B,L,D)
        return self.transition_block(attended)  # (B,L,D)


class BlockWrapper(nn.Module):
    """Post-LN residual wrapper: ``norm(x + dropout(sublayer(x, **kwargs)))``."""

    def __init__(self, config, sublayer_cls):
        super().__init__()
        self.sublayer = sublayer_cls(config)
        self.dropout = nn.Dropout(config.dropout_prob)
        self.norm = nn.LayerNorm(config.hidden_size)

    def forward(self, x, **kwargs):
        branch = self.dropout(self.sublayer(x, **kwargs))
        return self.norm(x + branch)


class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention using one fused projection for query, key and value.

    The fused output is laid out per head as [q | k | v] (each of width d), then
    split after moving heads to dim 1.
    """

    def __init__(self, config):
        super().__init__()
        self.mix_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.attention = Attention(config)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.H = config.num_attention_heads
        self.d = config.hidden_size // self.H

    def forward(
        self,
        x,  # (B,L,D)
        attention_mask,  # additive mask broadcastable to (B,H,L,L): 0 where attended, large negative where not
    ):
        batch, length, width = x.shape
        heads, head_dim = self.H, self.d
        fused = self.mix_proj(x).view(batch, length, heads, 3 * head_dim)  # (B,L,H,3d)
        fused = fused.transpose(1, 2)  # (B,H,L,3d)
        query, key, value = fused.split(head_dim, dim=-1)  # 3 x (B,H,L,d)
        context = self.attention(query, key, value, attention_mask)  # (B,H,L,d)
        merged = context.transpose(1, 2).reshape(batch, length, width)  # (B,L,D)
        return self.o_proj(merged)  # (B,L,D)


class Attention(nn.Module):
    """Scaled dot-product attention with an additive mask and dropout on the probabilities."""

    def __init__(self, config):
        super().__init__()
        self.dropout = nn.Dropout(config.dropout_prob)

    def forward(
        self,
        query,  # (B,H,L,d)
        key,  # (B,H,L,d)
        value,  # (B,H,L,d)
        attention_mask,  # additive mask broadcastable to (B,H,L,L): 0 where attended, large negative where not
    ):
        head_dim = key.shape[-1]
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(head_dim)  # (B,H,L,L)
        scores = scores + attention_mask  # (B,H,L,L)
        probs = self.dropout(torch.softmax(scores, dim=-1))  # (B,H,L,L)
        return torch.matmul(probs, value)  # (B,H,L,d)


class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward network: Linear(D->I) -> GELU -> Linear(I->D)."""

    def __init__(self, config):
        super().__init__()
        self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, x):  # (B,L,D)
        expanded = F.gelu(self.linear1(x))  # (B,L,intermediate_size)
        return self.linear2(expanded)  # (B,L,D)
