import logging
from typing import Callable, Iterable, Iterator

import torch
from transformers import BatchEncoding
from transformers.utils import ModelOutput


# Sentinel written into the model-batch index lists right after the last row
# of each input batch. Regular entries are row positions inside a model batch
# (always >= 0), so -1 cannot collide with them.
END_OF_BATCH = -1
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def map_to_model_batches(
    tokenized_batches: Iterable[BatchEncoding],
    model_batch_size: int = 512,
    trim_last_column: bool = True,
) -> Iterator[tuple[BatchEncoding, list[int]]]:
    """Regroup tokenized input batches into deduplicated model batches.

    Rows of consecutive input batches are buffered until
    ``model_batch_size`` distinct rows have been collected, at which point
    the buffer is stacked and yielded. A row whose token ids and attention
    mask (after the optional trimming) equal those of a row already in the
    current buffer is not buffered again; it only reuses that row's index.
    Deduplication does not reach across yielded model batches.

    Args:
        tokenized_batches: Input batches exposing ``input_ids`` and
            ``attention_mask`` tensors. All ``input_ids`` must have the
            same size along dimension 1.
        model_batch_size: Number of distinct rows per yielded model batch.
            Only the final model batch may contain fewer rows.
        trim_last_column: If true, the last column of ``input_ids`` and
            ``attention_mask`` is dropped before deduplication.

    Yields:
        ``(model_batch, input_to_model_batch_indices)`` pairs. The index
        list holds one entry per input row processed since the previous
        yield, giving the row's position inside ``model_batch``. An
        ``END_OF_BATCH`` entry follows the last row of every input batch.

    Raises:
        AssertionError: If an input batch has a different sequence length
            than the first one.

    NOTE: an input batch with zero rows adds nothing to the index list,
    not even an ``END_OF_BATCH`` marker.
    """
    batch_encoding: dict[str, list[torch.Tensor]] = {
        "input_ids": [],
        "attention_mask": [],
    }
    encoding_indices: dict[tuple, int] = {}
    input_to_model_batch_indices: list[int] = []
    input_sequence_length = -1

    for input_batch in tokenized_batches:
        input_ids: torch.Tensor = input_batch.input_ids
        attention_mask: torch.Tensor = input_batch.attention_mask
        if input_sequence_length == -1:
            input_sequence_length = input_ids.shape[1]
        else:
            assert (
                input_sequence_length == input_ids.shape[1]
            ), "All inputs must have the same sequence length"
        if trim_last_column:
            # We don't need to feed the last column to the model, since
            # we get its probabilities from the previous column.
            input_ids = input_ids[:, :-1]
            attention_mask = attention_mask[:, :-1]

        # Convert each tensor to Python scalars once per input batch rather
        # than calling ``.item()`` for every single element; the resulting
        # lookup keys are the same.
        id_rows = input_ids.tolist()
        mask_rows = attention_mask.tolist()
        last_row_idx = len(input_ids) - 1

        rows = zip(input_ids, attention_mask, id_rows, mask_rows)
        for row_idx, (input_row, attention_row, ids, mask) in enumerate(rows):
            # Construct a lookup key for the current row
            encoding_key = (tuple(ids), tuple(mask))
            input_row_idx = encoding_indices.get(encoding_key)
            if input_row_idx is None:
                # This is a new input, so we need to add it to the batch
                input_row_idx = len(batch_encoding["input_ids"])
                batch_encoding["input_ids"].append(input_row)
                batch_encoding["attention_mask"].append(attention_row)
                encoding_indices[encoding_key] = input_row_idx
            # Else this input already has an encoding in the current buffer
            # whose result we can reuse
            input_to_model_batch_indices.append(input_row_idx)

            if row_idx == last_row_idx:
                # The input batch is exhausted: add
                # a special termination indicator
                input_to_model_batch_indices.append(END_OF_BATCH)

            if len(batch_encoding["input_ids"]) == model_batch_size:
                # We filled up the batch, so we can pass it to the model
                # and clear the buffers
                # OPTIMIZE: if there are inputs later that map to already
                # encoded inputs in an earlier batch, they will be
                # reencoded.
                model_batch = _stack_model_batch(batch_encoding)
                batch_encoding["input_ids"].clear()
                batch_encoding["attention_mask"].clear()
                encoding_indices.clear()
                yield model_batch, input_to_model_batch_indices
                # Start a fresh list so the one just yielded is never
                # mutated afterwards.
                input_to_model_batch_indices = []

    # There is still a batch left that is not full
    if len(batch_encoding["input_ids"]) > 0:
        yield _stack_model_batch(batch_encoding), input_to_model_batch_indices


def _stack_model_batch(
    rows: dict[str, list[torch.Tensor]],
) -> BatchEncoding:
    """Stack the buffered per-row tensors into a single ``BatchEncoding``."""
    return BatchEncoding(
        {key: torch.stack(value) for key, value in rows.items()}
    )


def remap_to_input_batches() -> (
    Callable[[ModelOutput, list[int]], Iterator[ModelOutput]]
):
    """Create a stateful function that regroups model outputs per input batch.

    The returned function takes the output of one model batch together with
    the index list describing that batch, and returns a generator. The
    generator yields one ``ModelOutput`` for every ``END_OF_BATCH`` marker
    in the index list. Each yielded output holds the rows selected by the
    indices collected since the previous marker, taken from every model
    output still held and concatenated along dimension 0.

    Only tensor values and tuples consisting solely of tensors are carried
    over; any other value of a model output is dropped.

    The bookkeeping lives in the enclosing scope and is updated only while
    a generator is being iterated, so it is shared by all calls of the
    returned function.
    """
    pending_outputs: list[ModelOutput] = []
    pending_rows: list[list[int]] = []

    def _collect_rows() -> dict[str, list]:
        # Pick, for every held model output, the rows recorded for it.
        collected: dict[str, list] = {}
        for held_output, row_indices in zip(pending_outputs, pending_rows):
            for key, value in held_output.items():
                if isinstance(value, torch.Tensor):
                    picked = value[row_indices]
                elif isinstance(value, tuple) and all(
                    isinstance(v, torch.Tensor) for v in value
                ):
                    picked = tuple(t[row_indices] for t in value)
                else:
                    # Neither a tensor nor a tuple of tensors: skip it.
                    continue
                collected.setdefault(key, []).append(picked)
        return collected

    def _merge(chunks: list):
        # Join the per-model-batch pieces of one output field.
        if isinstance(chunks[0], tuple):
            return tuple(torch.cat(parts) for parts in zip(*chunks))
        return torch.cat(chunks, dim=0)

    def output_batch_generator(
        output: ModelOutput,
        input_to_model_batch_indices: list[int],
    ) -> Iterator[ModelOutput]:
        pending_outputs.append(output)
        pending_rows.append([])

        for position, row_index in enumerate(input_to_model_batch_indices):
            if row_index != END_OF_BATCH:
                # Regular entry: remember which row of the newest model
                # output belongs to the input batch being assembled.
                pending_rows[-1].append(row_index)
                continue

            # End of an input batch: gather its rows from all held outputs.
            collected = _collect_rows()

            if position + 1 == len(input_to_model_batch_indices):
                # The marker closes this model batch as well, so nothing
                # held so far is needed again.
                pending_outputs.clear()
                pending_rows.clear()
            else:
                # Further input rows refer to the newest model output;
                # keep only that one and start a fresh row list for it.
                del pending_outputs[:-1]
                pending_rows.clear()
                pending_rows.append([])

            yield ModelOutput(
                **{key: _merge(chunks) for key, chunks in collected.items()}
            )

    return output_batch_generator
