from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_outputs import SequenceClassifierOutput


class ModelForMultipleChoice(nn.Module):
    """Multiple-choice head that scores each choice from a mean-pooled token span.

    The wrapped backbone encodes the sequence once. ``span_mask`` marks, for
    every choice, which token positions belong to it. Each choice's hidden
    states are averaged over its span and projected to a single score by a
    bias-free linear layer.

    Args:
        model: Backbone called as ``model(input_ids=..., attention_mask=...)``
            whose output exposes ``last_hidden_state``.
        num_labels: Number of choices per example; logits are reshaped to
            ``(-1, num_labels)``.
        *inputs, **kwargs: Accepted for call-site compatibility and ignored.
        hidden_size: Width of the backbone's hidden states. When omitted it is
            read from ``model.config.hidden_size``, falling back to
            ``_FALLBACK_HIDDEN_SIZE`` if the backbone has no such attribute.
    """

    # Used only when neither ``hidden_size`` nor ``model.config.hidden_size``
    # is available.
    _FALLBACK_HIDDEN_SIZE = 4096

    def __init__(
        self,
        model: nn.Module,
        num_labels: int = 4,
        *inputs,
        hidden_size: Optional[int] = None,
        **kwargs,
    ):
        super().__init__()
        self.model = model
        self.num_labels = num_labels
        if hidden_size is None:
            # Only a missing ``config`` / ``hidden_size`` should trigger the
            # fallback; catching everything would also hide real errors and
            # KeyboardInterrupt.
            try:
                hidden_size = self.model.config.hidden_size
            except AttributeError:
                hidden_size = self._FALLBACK_HIDDEN_SIZE
        # One score per choice; no bias, so an all-zero embedding scores 0.
        self.classifier = nn.Linear(hidden_size, 1, bias=False)

    @property
    def config(self):
        """The wrapped backbone's ``config`` object."""
        return self.model.config

    @staticmethod
    def _mean_pool_spans(
        hidden_states: torch.Tensor, span_mask: torch.Tensor
    ) -> torch.Tensor:
        """Average ``hidden_states`` (N, L, D) over each span of ``span_mask`` (N, L, C).

        Returns a tensor of shape (N, C, D). A choice whose span is empty gets
        a zero vector instead of the NaN that 0 / 0 would produce.
        """
        weights = span_mask.to(hidden_states.dtype)
        # Weighted sum of token states per choice: (N, C, D).
        summed = torch.einsum("nlc,nld->ncd", weights, hidden_states)
        # Tokens per choice, (N, C). Summed on the original mask so bool/int
        # masks give exact integer counts regardless of the hidden dtype.
        counts = span_mask.sum(1)
        # An empty span has summed == 0; dividing by 1 gives a zero embedding
        # rather than NaN, which would otherwise poison the logits and loss.
        counts = counts.masked_fill(counts == 0, 1)
        return summed / counts[:, :, None]

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        span_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ):
        """Score every choice and optionally compute the cross-entropy loss.

        Args:
            input_ids: Token ids passed straight to the backbone.
            attention_mask: Attention mask passed straight to the backbone.
            span_mask: (N, L, C) mask selecting, for each choice ``c``, the
                token positions whose hidden states are averaged.
            labels: Optional targets for ``F.cross_entropy`` (presumably the
                index of the correct choice per example — confirm with caller).
            *args, **kwargs: Ignored.

        Returns:
            ``SequenceClassifierOutput`` whose ``logits`` have shape
            ``(-1, num_labels)`` and whose ``loss`` is set only when
            ``labels`` is given.
        """
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        choice_embeds = self._mean_pool_spans(outputs.last_hidden_state, span_mask)

        # Match the head's parameter dtype so e.g. a bf16 backbone can feed an
        # fp32 head without a dtype-mismatch error in the linear layer.
        choice_embeds = choice_embeds.to(self.classifier.weight.dtype)
        # (N, C, 1) -> (-1, num_labels). Assumes C == num_labels — TODO confirm
        # with the data collator; otherwise scores from different examples mix.
        logits = self.classifier(choice_embeds).view(-1, self.num_labels)
        loss = F.cross_entropy(logits, labels) if labels is not None else None

        # Keep references (not copies) to the most recent inputs on the module.
        self.prev_batch = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "span_mask": span_mask,
        }
        return SequenceClassifierOutput(loss=loss, logits=logits)
