import torch

from dataclasses import dataclass
from typing import Dict, Any, List


@dataclass
class DataCollator:
    """Build next-token-prediction inputs and labels from a list of token tensors.

    Attributes:
        eos_token: Token id written into the column prepended to every row
            when ``append_bos`` is True (the EOS id doubles as the start token).
        append_bos: If True, prepend ``eos_token`` to each row so ``labels``
            keep the full sequence. If False, ``input_ids`` and ``labels``
            are the sequence shifted against each other by one position.
    """

    eos_token: int
    append_bos: bool

    def __call__(self, batch: List[torch.Tensor]) -> Dict[str, Any]:
        """
        Collate a batch of tensors for training.

        Elements are concatenated along dim 0, and dim 1 is treated as the
        sequence axis that gets shifted. If every element is 1-D (a single
        sequence), the elements are stacked instead, giving a
        ``(len(batch), seq_len)`` tensor.

        Args:
            batch (List[torch.Tensor]): List of tensors to be collated. All
                elements must have the same sequence length.

        Returns:
            Dict[str, Any]: Dictionary with ``"input_ids"`` and ``"labels"``.
                Without BOS, both are ``seq_len - 1`` long (inputs drop the
                last token, labels drop the first). With BOS, both are
                ``seq_len`` long (inputs get ``eos_token`` prepended and drop
                the last token, labels are the unmodified sequence).

        Errors raised by ``torch.cat`` / ``torch.stack`` (for example an empty
        batch or mismatched shapes) propagate unchanged.
        """
        if batch and all(t.dim() == 1 for t in batch):
            # 1-D sequences have no row axis to concatenate along; stack them
            # so that the (rows, seq) slicing below applies.
            batch = torch.stack(batch, dim=0)
        else:
            batch = torch.cat(batch, dim=0)

        if not self.append_bos:
            return {
                "input_ids": batch[:, :-1],
                "labels": batch[:, 1:],
            }

        # The EOS id is reused as the start token for every row; new_full
        # keeps the dtype/device of the batch.
        bos_tokens = batch.new_full((batch.size(0), 1), self.eos_token)
        input_ids = torch.cat((bos_tokens, batch[:, :-1]), dim=-1)

        return {
            "input_ids": input_ids,
            "labels": batch,
        }