"""Multiple choices as suffixes for a language-model."""
import dataclasses
from typing import Dict, Optional, Sequence, Tuple

import torch
from transformers import PreTrainedModel, PreTrainedTokenizer

from npeff_torch.models import model_utils

###############################################################################

# Accepted non-None values for the `normalization` argument of
# `LmSuffixMcLogitsComputer`; any other non-None value is rejected with a
# ValueError in its constructor.
_NORMALIZATION_OPTIONS = (
    # Divide the log probability of a choice by the number of tokens in the choice.
    'choice_token_length',
)

###############################################################################


@dataclasses.dataclass
class _DummyModelOutput:
    """Minimal stand-in for a model output object, exposing only `logits`.

    Returned by `LmSuffixMcLogitsComputer.__call__` so that the computer can be
    called like a model.
    """
    # The tensor returned by `LmSuffixMcLogitsComputer.compute_logits`.
    logits: torch.Tensor

class LmSuffixMcLogitsComputer:
    """Scores multiple-choice options as suffixes of a context under a language model.

    Every option is a full token sequence (context followed by the choice). The
    score of an option is the sum of the model's log-probabilities of the choice
    tokens given everything before them, optionally normalized.

    The instance also exposes a small part of the model interface (`__call__`,
    `eval`, `named_parameters`, `zero_grad`) so it can be used where a model is
    expected.
    """

    def __init__(
        self, *,
        model: PreTrainedModel,
        device: Optional[torch.device],

        normalization: Optional[str] = None,
    ):
        """
        Args:
            model: The language model used to score the options.
            device: Default device used by `compute_logits` when none is passed.
            normalization: None for raw summed log-probabilities, or one of
                `_NORMALIZATION_OPTIONS`.

        Raises:
            ValueError: If `normalization` is neither None nor a known option.
        """
        if normalization is not None and normalization not in _NORMALIZATION_OPTIONS:
            raise ValueError(f'Invalid normalization: {normalization}')

        self._model = model
        self._device = device

        self._normalization = normalization

    def compute_logits(
        self,
        batch: Dict[str, torch.Tensor],
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """Compute one score per option.

        Required keys in batch:
            - "input_ids": shape=[batch, n_options, sequence]
            - "attention_mask": shape=[batch, n_options, sequence]
            - "context_length": shape=[batch]

        NOTE(review): the position arithmetic assumes that the first
        `context_length` positions of every option hold the shared context and
        that padding (if any) comes after the choice tokens -- confirm against
        the data pipeline.

        Args:
            batch: The tensors described above.
            device: Device to run on; falls back to the device given at
                construction when None.

        Returns:
            Tensor of shape [batch, n_options] with the summed log-probability
            of each option's choice tokens, divided by the number of choice
            tokens when normalization is 'choice_token_length'.
        """
        if device is None:
            device = self._device

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        context_length = batch["context_length"].to(device)

        batch_size, n_options, sequence_length = input_ids.shape

        # Need to flatten the "batch" dimensions to call the model. `reshape`
        # (rather than `view`) also accepts non-contiguous inputs.
        flat_batch = {
            'input_ids': input_ids.reshape(-1, sequence_length),
            'attention_mask': attention_mask.reshape(-1, sequence_length),
        }

        flat_logits = model_utils.compute_logits(self._model, flat_batch, device)
        raw_log_probs = torch.log_softmax(flat_logits, dim=-1)
        # raw_log_probs.shape = [batch, n_options, sequence, vocab]
        raw_log_probs = raw_log_probs.view(batch_size, n_options, *raw_log_probs.shape[-2:])

        # The output at position s is the prediction for the token at s + 1, so
        # the label for position s is the next input token. The last position
        # wraps around to token 0 and is masked out below.
        labels = torch.roll(input_ids, -1, dims=-1)
        # prediction_log_probs.shape = [batch, n_options, sequence]
        prediction_log_probs = torch.gather(raw_log_probs, -1, labels[..., None]).squeeze(-1)

        # Position s predicts a choice token iff s + 1 >= context_length.
        positions = torch.arange(sequence_length, device=device)
        predicts_choice = (context_length[:, None] - 1) <= positions[None, :]

        is_token = attention_mask != 0
        # A prediction only counts when its *target* (the token at s + 1) is a
        # real token. Gating on the attention mask at s alone would also count
        # the prediction made from the last real token, whose target is padding
        # (or the wrapped-around first token).
        target_is_token = torch.roll(is_token, -1, dims=-1)
        target_is_token[..., -1:] = False

        # mask.shape = [batch, n_options, sequence]
        mask = predicts_choice[:, None, :] & is_token & target_is_token

        # Zero out the excluded positions before summing; multiplying by a 0/1
        # mask would turn a -inf log-probability at an excluded position into NaN.
        log_probs = prediction_log_probs.masked_fill(~mask, 0.0).sum(dim=-1)

        if self._normalization == 'choice_token_length':
            option_lengths = is_token.type(context_length.dtype).sum(dim=-1)
            choice_lengths = option_lengths - context_length[:, None]
            log_probs = log_probs / choice_lengths.type(log_probs.dtype)

        return log_probs

    #######################################################
    # Hacks to be able to call this like a model.

    def __call__(self, **kwargs) -> _DummyModelOutput:
        """Treat the keyword arguments as a batch and wrap the scores as `logits`."""
        return _DummyModelOutput(logits=self.compute_logits(kwargs))

    def eval(self):
        """Call `eval()` on the wrapped model and return its result."""
        return self._model.eval()

    def named_parameters(self):
        """Return the wrapped model's `named_parameters()`."""
        return self._model.named_parameters()

    def zero_grad(self):
        """Call `zero_grad()` on the wrapped model and return its result."""
        return self._model.zero_grad()
