import logging

import torch
from transformers import BatchEncoding

from .metric import TokenMetric, TokenValues


# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


def sequence_token_probs(
    probs: TokenValues,
    target_sequences: BatchEncoding,
) -> TokenValues:
    """Extract the probabilities of the tokens present in the sequence
    from the distribution over all token probabilities.

    I.e. given probabilities for every possible token, extract only the
    one probability corresponding to the token that is present in the
    sequence at that index.
    E.g. this would allow you to compute the probability of the sequence
    ['a', 'x', 'y', 'z'] in the token probabilities.

    For a row with ``n`` attended tokens, the last ``n - 1`` positions of
    ``probs`` are paired with the last ``n - 1`` token ids: the first
    token has no preceding position that predicts it. Because the
    trailing positions are used, any padding is assumed to be on the left.

    Args:
        probs: The probability distribution over all tokens assigned
            by the model. ``probs.values`` is indexed as
            (batch, position, token id).
        target_sequences: The sequence(s) of tokens whose probabilities
            should be extracted. ``input_ids`` and ``attention_mask``
            may be 1D (one sequence shared by the whole batch) or 2D
            (one sequence per batch row).

    Returns:
        The probabilities for the tokens present in the sequence, with
        the same dtype and device as ``probs.values``. ``values`` holds
        one entry per position and is 0 wherever nothing was extracted;
        ``mask`` is True exactly at the extracted positions. Rows with
        fewer than two attended tokens have no predictable token, so
        they are all zeros with an all-False mask.

    Raises:
        AssertionError: If 2D targets do not have the same batch size
            as ``probs.values``.
    """
    batch_size = len(probs.values)
    if target_sequences.input_ids.dim() == 1:
        # 1D sequence, share it across the batch. `expand` creates
        # broadcast views instead of materialising `batch_size` copies.
        input_ids = target_sequences.input_ids.unsqueeze(0).expand(
            batch_size, -1
        )
        attention_mask = target_sequences.attention_mask.unsqueeze(
            0
        ).expand(batch_size, -1)
    else:
        input_ids = target_sequences.input_ids
        attention_mask = target_sequences.attention_mask
        assert len(input_ids) == batch_size, (
            "target_sequences.input_ids and probs differ in batch size"
        )
        assert len(attention_mask) == batch_size, (
            "target_sequences.attention_mask and probs differ in batch size"
        )

    token_probs: list[torch.Tensor] = []
    mask: list[torch.Tensor] = []
    for row_probs, row_token_ids, row_mask in zip(
        probs.values,
        input_ids,
        attention_mask,
    ):
        # TODO: Do this in a batched manner instead of using a loop
        # Allocate next to the inputs so that neither the device nor the
        # dtype of the probabilities is silently changed.
        row_token_probs = row_probs.new_zeros(row_probs.shape[0])
        row_output_mask = torch.zeros_like(row_token_probs, dtype=torch.bool)

        # Number of tokens that have a preceding position predicting them
        num_targets = int(row_mask.sum()) - 1
        # Guard against `num_targets <= 0`: a slice starting at `-0`
        # would select every position instead of none.
        if num_targets > 0:
            row_token_probs[-num_targets:] = torch.gather(
                row_probs[-num_targets:],
                dim=1,
                # The token ids are shifted relative to the probabilities
                # so that each position is paired with the next token
                index=row_token_ids[-num_targets:].unsqueeze(-1),
            ).squeeze(-1)
            row_output_mask[-num_targets:] = True

        token_probs.append(row_token_probs)
        mask.append(row_output_mask)
    return TokenValues(
        values=torch.stack(token_probs),
        mask=torch.stack(mask),
    )


class SequenceProbMetric(TokenMetric):
    """Accumulate the per-token probabilities of the target sequence(s).

    Each `update` extracts the probabilities of the target tokens with
    `sequence_token_probs` and stores them; `compute` concatenates
    everything stored so far along the first dimension.
    """

    def __init__(self, target_sequences: BatchEncoding):
        """Set up the metric.

        Args:
            target_sequences: The sequence(s) of tokens whose
                probabilities should be tracked.
        """
        super().__init__(["token_probs"])
        self.target_sequences = target_sequences

        # Only the extracted values are kept; the mask returned by
        # `sequence_token_probs` is not stored.
        self.sequence_probs: list[torch.Tensor]
        self.add_state("sequence_probs", default=[], dist_reduce_fx="cat")

    def update(self, token_probs: TokenValues) -> None:
        """Extract and store the target token probabilities of one batch."""
        extracted = sequence_token_probs(
            probs=token_probs,
            target_sequences=self.target_sequences,
        )
        self.sequence_probs.append(extracted.values)

    def compute(self) -> torch.Tensor:
        """Return all stored probabilities, concatenated along dim 0."""
        collected = self.sequence_probs
        return torch.cat(collected, dim=0)

    def to(self, device: torch.device) -> "SequenceProbMetric":
        """Move the target sequences, then the metric itself, to `device`."""
        moved_targets = self.target_sequences.to(device)
        self.target_sequences = moved_targets
        return super().to(device)


class CorrectnessMetric(TokenMetric):
    """Track whether the most probable token equals the target token.

    For every position the token id with the highest probability is
    compared with the target token that follows that position.

    NOTE: The attention mask of the targets is not consulted, so padded
    positions are compared like any other position. Results are only
    meaningful for targets without padding (equal-length sequences).
    """

    def __init__(self, target_sequences: BatchEncoding):
        """Set up the metric.

        Args:
            target_sequences: The sequence(s) of tokens the predictions
                are compared against. ``input_ids`` may be 2D (one
                sequence per batch row) or 1D (one sequence shared by
                the whole batch).
        """
        super().__init__(["token_probs"])
        self.target_sequences = target_sequences

        self.prediction_correctness: list[torch.Tensor]
        self.add_state(
            "prediction_correctness", default=[], dist_reduce_fx="cat"
        )

    def update(self, token_probs: TokenValues) -> None:
        """Store the per-position correctness of one batch.

        Args:
            token_probs: Probabilities over all tokens; the last
                dimension of ``token_probs.values`` is the token id.
        """
        max_prob_token_ids = token_probs.values.argmax(dim=-1)
        # Each position predicts the following token, so the first target
        # token is dropped. Indexing with `...` also accepts a single 1D
        # sequence, which then broadcasts across the batch.
        target_ids = self.target_sequences.input_ids[..., 1:]
        predictions_correct = (max_prob_token_ids == target_ids).to(
            torch.long
        )
        self.prediction_correctness.append(predictions_correct)

    def compute(self) -> torch.Tensor:
        """Return all stored 0/1 correctness values, concatenated on dim 0."""
        return torch.cat(self.prediction_correctness, dim=0)

    def to(self, device: torch.device) -> "CorrectnessMetric":
        """Move the target sequences, then the metric itself, to `device`.

        The targets are a plain attribute, so they have to be moved
        explicitly to stay on the same device as the predictions they
        are compared with in `update`.
        """
        self.target_sequences = self.target_sequences.to(device)
        return super().to(device)


class LossMetric(TokenMetric):
    """Track the per-token cross-entropy loss against the target tokens.

    NOTE: The attention mask of the targets is not consulted, so padded
    positions receive a loss value like any other position. Results are
    only meaningful for targets without padding (equal-length sequences).
    """

    def __init__(self, target_sequences: BatchEncoding):
        """Set up the metric.

        Args:
            target_sequences: The sequence(s) of tokens the logits are
                scored against. ``input_ids`` may be 2D (one sequence
                per batch row) or 1D (one sequence shared by the whole
                batch).
        """
        super().__init__(["logits"])
        self.target_sequences = target_sequences

        self.losses: list[torch.Tensor]
        self.add_state("losses", default=[], dist_reduce_fx="cat")

    def update(self, logits: torch.Tensor) -> None:
        """Store the unreduced loss of one batch.

        Args:
            logits: Unnormalised scores, expected to be indexed as
                (batch, position, token id).
        """
        # Each position predicts the following token, so the first target
        # token is dropped.
        target_ids = self.target_sequences.input_ids[..., 1:]
        if target_ids.dim() == 1:
            # A single 1D sequence is shared by every row of the batch
            target_ids = target_ids.unsqueeze(0).expand(logits.shape[0], -1)
        loss = torch.nn.functional.cross_entropy(
            # cross_entropy expects the class dimension at index 1
            logits.swapaxes(1, 2),
            target_ids,
            # Get a loss value for each element
            reduction="none",
        )
        self.losses.append(loss)

    def compute(self) -> torch.Tensor:
        """Return all stored per-token losses, concatenated along dim 0."""
        return torch.cat(self.losses, dim=0)

    def to(self, device: torch.device) -> "LossMetric":
        """Move the target sequences, then the metric itself, to `device`.

        The targets are a plain attribute, so they have to be moved
        explicitly to stay on the same device as the logits they are
        scored against in `update`.
        """
        self.target_sequences = self.target_sequences.to(device)
        return super().to(device)
