from __future__ import annotations

from typing import List

import torch


def extract_continuations(
    generated_ids: torch.Tensor, input_ids: torch.Tensor
) -> List[List[int]]:
    """Return the tokens that follow the original (padded) input in each row.

    This assumes ``generated_ids`` holds the full sequences: the original
    ``input_ids`` (including padding) followed by the newly generated tokens,
    as decoder-only Hugging Face ``generate()`` calls typically return. Under
    that assumption the cut position is the padded input length,
    ``input_ids.shape[1]``, regardless of padding side, because every row in
    the batch shares that length.

    NOTE(review): if the caller's model does not echo the prompt in its output
    (e.g. some encoder-decoder setups), slicing at ``input_ids.shape[1]`` drops
    real output tokens -- confirm against the call site.

    Args:
        generated_ids: 2-D tensor of shape (batch, seq_len + new_tokens).
        input_ids: 2-D tensor of shape (batch, seq_len) used to produce
            ``generated_ids``. Only its shape is read.

    Returns:
        One list of token ids per batch row, containing everything from column
        ``seq_len`` onward. All rows have the same length, because the slice is
        taken at the same column for every row; no padding or end-of-sequence
        tokens are stripped. If ``generated_ids`` has no columns beyond
        ``seq_len``, each row is an empty list. An empty batch yields ``[]``.

    Raises:
        AssertionError: If either tensor is not 2-D, or the batch sizes differ.
    """
    # Raised explicitly rather than via `assert` statements so the checks still
    # run under `python -O`; AssertionError and the messages are kept so that
    # existing callers observe the same failure as before.
    if generated_ids.dim() != 2 or input_ids.dim() != 2:
        raise AssertionError("expected (B, T) tensors")
    if generated_ids.shape[0] != input_ids.shape[0]:
        raise AssertionError("batch sizes must match")

    prompt_len = int(input_ids.shape[1])

    # Slice every row at once and convert a single time instead of calling
    # .tolist() once per row.
    return generated_ids[:, prompt_len:].tolist()
