from typing import List, Optional, Union

from torch import LongTensor
from transformers import PreTrainedTokenizer


def postprocess_generation_ids(
    input_ids: LongTensor,
    output_ids: LongTensor,
    num_return_sequences: int,
    tokenizer: Optional["PreTrainedTokenizer"] = None,
    pad_token_ids: Optional[int] = None,
) -> List[List[Union[str, List[int]]]]:
    """Strip prompt tokens from generation output and group rows per prompt.

    ``output_ids`` is consumed in consecutive chunks of ``num_return_sequences``
    rows; chunk ``idx`` is paired with prompt row ``input_ids[idx]``. From every
    row in a chunk, the first ``input_ids[idx].size(0)`` positions are removed so
    that only newly generated tokens remain.

    Args:
        input_ids: Prompt token ids; row ``idx`` is the prompt for the
            ``idx``-th chunk of ``output_ids``.
        output_ids: Generated token ids (prompt prefix included), with
            ``num_return_sequences`` consecutive rows per prompt.
        num_return_sequences: Number of output rows per prompt; must be >= 1.
        tokenizer: If not ``None``, the generated ids of each chunk are
            decoded with ``tokenizer.batch_decode`` (using
            ``clean_up_tokenization_spaces=True``).
        pad_token_ids: Only used when ``tokenizer`` is ``None``. If not
            ``None``, each generated id list is cut at the first occurrence
            of this id (the id itself is dropped).

    Returns:
        One list per chunk, each holding ``num_return_sequences`` entries
        (fewer for a trailing partial chunk): decoded strings when a
        tokenizer is given, otherwise lists of ints.

    Raises:
        ValueError: If ``num_return_sequences`` is less than 1.
    """
    if num_return_sequences < 1:
        raise ValueError(
            f"num_return_sequences must be >= 1, got {num_return_sequences}"
        )

    outputs = []
    for idx, start in enumerate(range(0, len(output_ids), num_return_sequences)):
        sub_output_ids = output_ids[start : start + num_return_sequences]
        # Remove the prompt prefix so only the newly generated tokens remain.
        prompt_len = input_ids[idx].size(0)
        sub_generated_ids = sub_output_ids[..., prompt_len:]

        if tokenizer is not None:
            outputs.append(
                list(
                    tokenizer.batch_decode(
                        sub_generated_ids, clean_up_tokenization_spaces=True
                    )
                )
            )
            continue

        # Work on the generated slice (not the full rows): padded prompts may
        # contain pad tokens, which would otherwise truncate the whole output.
        rows = []
        for row in sub_generated_ids.tolist():
            if pad_token_ids is not None and pad_token_ids in row:
                row = row[: row.index(pad_token_ids)]
            rows.append(row)
        outputs.append(rows)

    return outputs


__all__ = [
    # Names exported by `from <module> import *`.
    "postprocess_generation_ids",
]
