"""Common stuff for multiple-choice question answering using LMs."""
import dataclasses
from typing import Dict, Optional, Sequence, Tuple

import torch
from transformers import PreTrainedModel, PreTrainedTokenizer

from npeff_torch.models import model_utils

###############################################################################


def _make_answer_label_token_ids(
    tokenizer: PreTrainedTokenizer,
    answer_labels: Sequence[str],
    answer_label_prefix: str,
) -> Tuple[int, ...]:
    """Returns the unique first token id of each answer.

    Each label is tokenized as `answer_label_prefix + answer_label` and the id
    of its first token is kept.

    Args:
        tokenizer: Tokenizer used to encode the (prefixed) answer labels.
        answer_labels: The answer labels, e.g. 'A', 'B', ...
        answer_label_prefix: String prepended to every label before encoding.

    Returns:
        One token id per answer label, in the order of `answer_labels`.

    Raises:
        ValueError: If a label tokenizes to nothing, or if two labels share
            the same first token id.
    """
    token_ids = []
    for answer_label in answer_labels:
        # Special tokens must be disabled: a tokenizer that prepends BOS/CLS
        # would otherwise make the "first token" identical for every label.
        encoded = tokenizer.encode(
            answer_label_prefix + answer_label,
            add_special_tokens=False,
        )
        if not encoded:
            raise ValueError(
                f'The answer label {answer_label!r} tokenizes to an empty sequence.'
            )
        token_ids.append(encoded[0])

    if len(set(token_ids)) != len(token_ids):
        raise ValueError('The answer labels must have a unique first token when tokenized.')
    return tuple(token_ids)


###############################################################################


# Focus just on computing logits here.

# Need a list of answer labels (like A, B, ...); given the tokenizer, make sure their first tokens are unique.

# EOS stuff???

# Space stuff? Whether to prefix the labels with newline, space, nothing, ...

@dataclasses.dataclass
class _DummyModelOutput:
    """Minimal stand-in for a model output object; it only carries `logits`."""
    # The logits tensor handed back to callers that expect `output.logits`.
    logits: torch.Tensor


class LmMcqaLogitsComputer:
    """Computes per-answer-label logits for multiple-choice QA with an LM.

    Every answer label is represented by the id of the first token of its
    (prefixed) tokenization. For each example, the logits of those tokens are
    read at the position of the last attended token of the sequence.

    The class also exposes a few model-like methods (`__call__`, `eval`,
    `named_parameters`, `zero_grad`) so that it can be handed to code that
    expects a model.
    """

    def __init__(
        self, *,

        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,

        answer_labels: Sequence[str],
        # Depending on how the task is prompted and how the tokenizer works, we might
        # wish to prefix each answer label with a space or something.
        answer_label_prefix: str = '',

        device: Optional[torch.device] = None,
    ):
        """Stores the model and precomputes the answer-label token ids.

        Args:
            model: The language model whose logits are read.
            tokenizer: Tokenizer used to map answer labels to token ids.
            answer_labels: The answer labels, e.g. 'A', 'B', ...
            answer_label_prefix: String prepended to each label before it is
                tokenized.
            device: Default device passed on when computing logits; also the
                device on which the answer-label token ids are created.

        Raises:
            ValueError: If the answer labels do not have unique first tokens.
        """
        self._model = model
        self._device = device

        self.answer_labels = answer_labels
        self.answer_label_prefix = answer_label_prefix

        self._answer_label_token_ids = torch.tensor(
            _make_answer_label_token_ids(tokenizer, answer_labels, answer_label_prefix),
            dtype=torch.int64,
            device=device,
        )

    def compute_logits(
        self,
        batch: Dict[str, torch.Tensor],
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """Returns the model's logits restricted to the answer-label tokens.

        The returned values are the raw logits of the answer-label tokens.
        Taking a softmax over the last dimension therefore corresponds to
        restricting the model's next-token distribution to the token ids of
        interest and normalizing the resultant distribution.

        Args:
            batch: Model inputs. Must contain an 'attention_mask' entry whose
                padding (zeros) is located at the end of each sequence.
            device: Device forwarded to the logits computation. Defaults to
                the device given at construction time.

        Returns:
            Tensor of shape [batch, n_answer_labels].
        """
        if device is None:
            device = self._device

        # raw_logits.shape = [batch, sequence, vocab]
        raw_logits = model_utils.compute_logits(self._model, batch, device)

        # Get the positions of the tokens where to predict. Assumes that all
        # padding happens at the end of the sequence.
        attention_mask = batch['attention_mask'].to(raw_logits.device)
        token_positions = (attention_mask != 0).type(torch.int64).sum(dim=-1) - 1

        # The token ids were created on the constructor's device, which may
        # differ from where the logits ended up (e.g. a `device` override).
        answer_label_token_ids = self._answer_label_token_ids.to(raw_logits.device)

        return _extract_from_sequence_logits(raw_logits, token_positions, answer_label_token_ids)

    #######################################################
    # Hacks to be able to call this like a model.

    def __call__(self, **kwargs) -> _DummyModelOutput:
        """Treats the keyword arguments as a batch and wraps the logits."""
        return _DummyModelOutput(logits=self.compute_logits(kwargs))

    def eval(self):
        """Delegates to the wrapped model's `eval()` and returns its result."""
        return self._model.eval()

    def named_parameters(self):
        """Delegates to the wrapped model's `named_parameters()`."""
        return self._model.named_parameters()

    def zero_grad(self):
        """Delegates to the wrapped model's `zero_grad()`."""
        return self._model.zero_grad()


###############################################################################


# Put into a separate function for testing purposes.
def _extract_from_sequence_logits(
    # shape = [batch, sequence, vocab], dtype=float32
    raw_logits: torch.Tensor,
    # shape = [batch], dtype=int64
    token_positions: torch.Tensor,
    # shape = [n_answer_labels], dtype=int64
    answer_label_token_ids: torch.Tensor,
) -> torch.Tensor:
    """Gathers the answer-label logits at one sequence position per example.

    Entry [b, k] of the result is
    `raw_logits[b, token_positions[b], answer_label_token_ids[k]]`.

    Args:
        raw_logits: Logits over the full vocabulary for every position.
        token_positions: For each example, the sequence position to read.
        answer_label_token_ids: Vocabulary ids of the answer-label tokens.

    Returns:
        Tensor of shape [batch, n_answer_labels].
    """
    # ret.shape = [batch, n_answer_labels]
    device = raw_logits.device
    token_positions = token_positions.to(device=device, dtype=torch.int64)
    answer_label_token_ids = answer_label_token_ids.to(device=device, dtype=torch.int64)

    batch_index = torch.arange(raw_logits.shape[0], dtype=torch.int64, device=device)

    # Broadcasted advanced indexing: [batch, 1] x [batch, 1] x [1, n_labels]
    # -> [batch, n_labels]. Unlike flattening with `view(-1)`, this does not
    # require `raw_logits` to be contiguous in memory.
    return raw_logits[
        batch_index[:, None],
        token_positions[:, None],
        answer_label_token_ids[None, :],
    ]
