# --------------------------------------------------------
# InternVL
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

import bisect
import copy
import logging
from collections import defaultdict
from typing import List, Union

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset, get_worker_info
from transformers.trainer_pt_utils import LabelSmoother

from .constants import IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN

# Label value treated as "ignored" throughout this module: tokens whose label
# equals it are counted as inactive in `PackedDataset.check_valid` and get a
# zero loss weight in `packed_collate_fn`.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
# Module-level logger. The level is pinned to INFO here, so the
# `logger.debug(...)` calls in this file are filtered out unless the level is
# lowered elsewhere.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and initialized."""
    # Short-circuits: `is_initialized` is only queried when the package is available.
    return dist.is_available() and dist.is_initialized()


def get_world_size():
    """Return the distributed world size, or 1 when no process group is active."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1


def get_rank():
    """Return this process' distributed rank, or 0 when no process group is active."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0


class PackedDataset(IterableDataset):
    """Iterable dataset that packs samples from several datasets into buffers.

    Samples are drawn from the wrapped datasets according to ``dataset_weight``
    and concatenated into buffers (dicts of tensors). A buffer is yielded when
    it holds exactly ``num_images_expected`` images, when it reaches
    ``max_packed_tokens`` tokens, or when more than ``max_buffer_size``
    partially filled buffers are pending.

    Each wrapped dataset is expected to be iterable and to expose ``ds_name``,
    ``dataset_type``, ``max_num_images``, ``max_tokens`` and
    ``load_state_dict``; ``worker_id``, ``worker_state_key`` and
    ``num_workers`` are assigned on it when iteration starts. Samples are dicts
    containing at least ``input_ids`` and ``pixel_values`` tensors and,
    optionally, a ``meta_info`` dict that is merged into the resumable state.
    """

    def __init__(
        self,
        tokenizer,
        data_rank,
        data_world_size,
        datasets: List,
        dataset_weight: List[int] = None,
        num_images_expected: int = 6,
        max_packed_tokens: int = 32768,
        max_buffer_size: int = 100,
        log_freq: int = 1000000,
        strict_mode: bool = False,
        debug_mode: bool = False,
        replacement: bool = True,
        allow_overflow: bool = True,
        allow_empty_data: bool = False,
        allow_deduplicated_ds_name: bool = False,
    ):
        """Set up the packing configuration and the per-dataset state.

        Args:
            tokenizer: Object providing ``convert_tokens_to_ids`` and
                ``unk_token_id``; used to resolve the image marker token ids.
            data_rank: Rank used, together with the DataLoader worker id, to
                derive a global worker id in ``__iter__``.
            data_world_size: Multiplied by the DataLoader ``num_workers`` to
                derive the global number of workers in ``__iter__``.
            datasets: Datasets to draw samples from.
            dataset_weight: Relative sampling weights, normalised to sum to 1.
                Defaults to uniform weights.
            num_images_expected: Target number of images per packed buffer;
                each dataset's ``max_num_images`` is clamped to it.
            max_packed_tokens: Token budget per packed buffer; each dataset's
                ``max_tokens`` is clamped to it.
            max_buffer_size: Number of pending buffers above which the oldest
                ones are yielded even if not full.
            log_freq: ``print_log`` emits a message every ``log_freq`` iterations.
            strict_mode: Pad under-filled buffers with zero images up to
                ``num_images_expected`` before yielding them.
            debug_mode: Yield copies of the first complete buffer forever.
            replacement: Restart an exhausted dataset instead of dropping it.
            allow_overflow: Allow ``find_buffer`` to select a buffer whose
                merged length exceeds ``max_packed_tokens`` once at least
                ``max_buffer_size // 2`` buffers are pending.
            allow_empty_data: Stored on the instance; enabling it only
                triggers a warning in this class.
            allow_deduplicated_ds_name: When False, duplicated ``ds_name``
                values fail an assertion.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.data_rank = data_rank
        self.data_world_size = data_world_size
        self.datasets = datasets
        self.num_images_expected = num_images_expected
        self.max_buffer_size = max_buffer_size
        self.log_freq = log_freq
        self.strict_mode = strict_mode
        self.debug_mode = debug_mode
        self.replacement = replacement
        self.allow_overflow = allow_overflow
        self.allow_empty_data = allow_empty_data

        self.max_packed_tokens = max_packed_tokens

        self.img_start_token_id = self.tokenizer.convert_tokens_to_ids(IMG_START_TOKEN)
        self.img_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN)
        # The image marker tokens must be known to the tokenizer, otherwise
        # split_buffer could not locate image boundaries.
        assert self.img_start_token_id != self.tokenizer.unk_token_id
        assert self.img_token_id != self.tokenizer.unk_token_id
        assert self.img_end_token_id != self.tokenizer.unk_token_id

        if dataset_weight is None:
            dataset_weight = [1] * len(datasets)
        self.dataset_type = [d.dataset_type for d in self.datasets]

        # The *_orig lists are the pristine configuration; the working copies
        # below shrink when a dataset is dropped and are rebuilt in __iter__.
        weight_total = sum(dataset_weight)
        self.datasets_orig = datasets
        self.dataset_weight_orig = [w / weight_total for w in dataset_weight]

        self.datasets = list(self.datasets_orig)
        self.dataset_weight = list(self.dataset_weight_orig)

        # lazy init
        self.worker_id = None
        self.worker_state_key = None
        self.dataset_iter_list = None
        self._state_dict = {
            'sample_info': {d.ds_name: 0 for d in self.datasets},
        }

        self.worker_custom_infos = None

        ds_name_list = [d.ds_name for d in self.datasets]
        if not allow_deduplicated_ds_name:
            assert len(ds_name_list) == len(set(ds_name_list)), f'deduplicated ds_name: {ds_name_list}'

        for ds in self.datasets:
            if ds.max_num_images > self.num_images_expected:
                logger.warning(f'{ds.max_num_images=} of {ds.ds_name} is larger than {self.num_images_expected=}')
                ds.max_num_images = num_images_expected

            if ds.max_tokens > self.max_packed_tokens:
                logger.warning(f'{ds.max_tokens=} of {ds.ds_name} is larger than {self.max_packed_tokens=}')
                ds.max_tokens = self.max_packed_tokens

            self._state_dict[ds.ds_name] = {}

        if get_rank() == 0:
            logger.info(
                f'Loaded dataset to pack: {ds_name_list}, '
                f'{self.num_images_expected=}, {self.max_packed_tokens=}, '
                f'{self.replacement=}, {self.allow_overflow=}',
            )

            temp = '\n'.join(
                f'{ds.ds_name:<25}: {ds_w*100:.2f}%'
                for ds, ds_w in zip(self.datasets, self.dataset_weight)
            )
            logger.info(
                f'Sampling prob for each dataset:\n{temp}'
            )

        if self.allow_empty_data:
            logger.warning('allow_empty_data is enabled, note that empty data may be generated!')

    def load_state_dict(self, state_dict, custom_infos=None):
        """Merge ``state_dict`` into the internal state and resume each dataset.

        Args:
            state_dict: Mapping merged into ``self._state_dict``; entries keyed
                by a dataset's ``ds_name`` are forwarded to that dataset's
                ``load_state_dict``.
            custom_infos: Optional per-worker extras (e.g. a saved
                ``buffer_list``) consumed on the next ``__iter__`` call.
        """
        self.worker_custom_infos = custom_infos

        self._state_dict.update(state_dict)
        for ds in self.datasets:
            if ds.ds_name in self._state_dict:
                ds.load_state_dict(self._state_dict[ds.ds_name])
                logger.info(f'{ds.ds_name=} is resumed.')
            else:
                logger.warning(f'{ds.ds_name=} is not resumed.')

    def _should_log(self):
        """Return True only for DataLoader worker 0 of distributed rank 0."""
        worker_info = get_worker_info()
        local_worker_id = 0 if worker_info is None else worker_info.id
        num_workers = 1 if worker_info is None else worker_info.num_workers

        return num_workers * get_rank() + local_worker_id == 0

    def _drop_dataset(self, dataset_idx):
        """Remove the dataset at ``dataset_idx`` and return a random remaining index.

        Raises:
            StopIteration: If no dataset is left after the removal.
        """
        self.datasets.pop(dataset_idx)
        self.dataset_iter_list.pop(dataset_idx)
        self.dataset_weight.pop(dataset_idx)

        if len(self.datasets) == 0:
            raise StopIteration
        return np.random.choice(len(self.datasets))

    def next_data(self, current_dataset_idx):
        """Fetch the next sample, preferring the dataset at ``current_dataset_idx``.

        An exhausted dataset is restarted when ``self.replacement`` is set and
        dropped otherwise (or when the restart itself fails); after a drop, a
        random remaining dataset is tried. The returned sample gets a
        ``type_ids`` tensor filled with the index of the dataset it came from,
        and its ``meta_info`` entry (if any) is moved into the resumable state.

        Raises:
            StopIteration: When every dataset has been dropped.
        """
        while True:
            try:
                current_sample = next(self.dataset_iter_list[current_dataset_idx])
                break  # Exit loop if successful
            except StopIteration:
                if self.replacement:
                    try:
                        self.dataset_iter_list[current_dataset_idx] = iter(self.datasets[current_dataset_idx])
                        current_sample = next(self.dataset_iter_list[current_dataset_idx])
                        break
                    except Exception:
                        # The restarted dataset produced nothing either:
                        # fall through and drop it like an exhausted one.
                        pass
                current_dataset_idx = self._drop_dataset(current_dataset_idx)
            except Exception:
                # Keep the original best-effort behaviour (try another
                # dataset), but record the traceback instead of hiding it.
                logger.error('Unexpected error!', exc_info=True)
                if len(self.datasets) == 0:
                    raise StopIteration
                current_dataset_idx = np.random.choice(len(self.datasets))

        current_ds_name = self.datasets[current_dataset_idx].ds_name
        current_sample['type_ids'] = torch.zeros_like(current_sample['input_ids']) + current_dataset_idx

        if self.worker_state_key not in self._state_dict[current_ds_name]:
            self._state_dict[current_ds_name][self.worker_state_key] = {}

        meta_info = current_sample.pop('meta_info', {})
        self._state_dict[current_ds_name][self.worker_state_key].update(**meta_info)
        self._state_dict['sample_info'][current_ds_name] += 1
        return current_sample

    def find_buffer(self, buffer_list, new_sample):
        """Pop and return a pending buffer that ``new_sample`` can be merged into.

        A buffer qualifies when the merged image count stays within
        ``num_images_expected``. The first qualifying buffer whose merged
        token count fits ``max_packed_tokens`` is chosen. If none fits and
        overflow is allowed (``allow_overflow`` and at least
        ``max_buffer_size // 2`` pending buffers), the last qualifying buffer
        is chosen instead. Returns ``None`` when nothing is selected.
        """
        # NOTE: use `bisect` to search might be faster

        overflow_ok = self.allow_overflow and len(buffer_list) >= self.max_buffer_size // 2
        num_images_current = new_sample['pixel_values'].size(0)
        num_tokens_current = new_sample['input_ids'].size(0)

        find_idx = None
        for buffer_idx, buffer in enumerate(buffer_list):
            if buffer['pixel_values'].size(0) + num_images_current > self.num_images_expected:
                continue

            if num_tokens_current + buffer['input_ids'].size(0) <= self.max_packed_tokens:
                find_idx = buffer_idx
                break

            # No break here on purpose: a later buffer that fits exactly
            # still takes precedence over an overflowing one.
            if overflow_ok:
                find_idx = buffer_idx

        if find_idx is None:
            return None
        return buffer_list.pop(find_idx)

    def update_buffer(self, buffer, new_sample):
        """Append ``new_sample`` to ``buffer`` and return the merged buffer.

        ``new_sample`` receives a ``data_index`` tensor numbering it within
        the buffer (0 for the first sample, previous last index + 1
        otherwise). When ``buffer`` is ``None`` the sample itself becomes the
        buffer. Both dicts must have identical keys.
        """
        if buffer is None:
            new_sample['data_index'] = torch.zeros_like(new_sample['input_ids'])
            return new_sample

        new_sample['data_index'] = torch.ones_like(new_sample['input_ids']) + buffer['data_index'][-1].item()

        assert buffer.keys() == new_sample.keys()
        for k in buffer:
            buffer[k] = torch.cat([buffer[k], new_sample[k]])
        return buffer

    @staticmethod
    def check_valid(sample_to_check, min_active_tokens_ratio=1/256):
        """Return whether the share of non-ignored labels exceeds the given ratio."""
        num_ignore_tokens = (sample_to_check['labels'] == IGNORE_TOKEN_ID).sum()
        num_tokens = sample_to_check['labels'].numel()
        return (1 - num_ignore_tokens / num_tokens) > min_active_tokens_ratio

    @staticmethod
    def split_buffer(buffer, max_tokens, img_start_token_id, img_token_id, img_end_token_id):
        """Split a buffer longer than ``max_tokens`` into pieces without cutting an image.

        Token-level entries are cut at token indices and ``pixel_values`` /
        ``image_flags`` at the matching image indices. Pieces that contain no
        image or fail ``check_valid`` are discarded.

        Returns:
            A list of buffers; ``[buffer]`` unchanged when it already fits.
        """
        if buffer['input_ids'].size(0) <= max_tokens:
            return [buffer]

        def _image_is_splitted(input_ids, cut_idx):
            # True when the token at the cut position belongs to an image span.
            token = input_ids[cut_idx].item()
            return token in (img_start_token_id, img_token_id, img_end_token_id)

        def _split(sample_to_split, left_idx, right_idx, left_img_idx, right_img_idx):
            # Slice token-level keys by token index and image-level keys by
            # image index; `right_*` may be None to request no right part.
            assert (right_idx is None) == (right_img_idx is None)

            left_sample = {}
            right_sample = {} if right_idx is not None else None
            for k in sample_to_split:
                if k in ['input_ids', 'labels', 'attention_mask', 'position_ids', 'data_index', 'type_ids']:
                    left_sample[k] = sample_to_split[k][:left_idx]
                    if right_sample is not None:
                        right_sample[k] = sample_to_split[k][right_idx:]
                elif k in ['pixel_values', 'image_flags']:
                    left_sample[k] = sample_to_split[k][:left_img_idx]
                    if right_sample is not None:
                        right_sample[k] = sample_to_split[k][right_img_idx:]
                else:
                    raise NotImplementedError(f'find unsupported keys: {k} from {sample_to_split.keys()}')
            return left_sample, right_sample

        splitted_buffer = []
        while buffer['input_ids'].size(0) > max_tokens:
            img_start_idx_list = (buffer['input_ids'] == img_start_token_id).nonzero().squeeze(1).tolist()
            img_end_idx_list = (buffer['input_ids'] == img_end_token_id).nonzero().squeeze(1).tolist()
            assert len(img_start_idx_list) == len(img_end_idx_list)

            if _image_is_splitted(buffer['input_ids'], max_tokens):
                # The cut position lies inside an image span: move the cut to
                # the start of that image so it goes to the right piece whole.
                cut_idx = bisect.bisect_left(img_start_idx_list, max_tokens)
                if buffer['input_ids'][max_tokens] == img_start_token_id:
                    assert max_tokens == img_start_idx_list[cut_idx]
                    cut_left_idx = img_start_idx_list[cut_idx]
                    cut_left_img_idx = cut_idx
                else:
                    cut_left_idx = img_start_idx_list[cut_idx - 1]
                    cut_left_img_idx = cut_idx - 1
                cut_right_idx = cut_left_idx
                cut_right_img_idx = cut_left_img_idx

                if cut_left_idx == 0:
                    # The straddling image starts at index 0, so the left piece
                    # would be empty and the right piece would be the whole
                    # buffer again: nothing can be consumed, stop instead of
                    # looping forever.
                    logger.warning(
                        f'cannot split a buffer whose first image does not fit into {max_tokens=}, '
                        f'the remaining tokens are dropped'
                    )
                    break
            else:
                # The cut position is a non-image token: the left piece takes
                # exactly `max_tokens` tokens and the right piece resumes at
                # the next image start (tokens in between are discarded).
                cut_img_idx = bisect.bisect(img_start_idx_list, max_tokens)
                if cut_img_idx < len(img_start_idx_list):
                    cut_right_idx = img_start_idx_list[cut_img_idx]
                    cut_right_img_idx = cut_img_idx
                else:
                    cut_right_idx = None
                    cut_right_img_idx = None

                cut_left_idx = max_tokens
                cut_left_img_idx = cut_right_img_idx if cut_right_img_idx is not None else buffer['pixel_values'].size(0)

            left, right = _split(
                sample_to_split=buffer,
                left_idx=cut_left_idx,
                left_img_idx=cut_left_img_idx,
                right_idx=cut_right_idx,
                right_img_idx=cut_right_img_idx,
            )

            # Every piece must hold as many image tensors as complete
            # start/end marker pairs.
            assert (left['input_ids'] == img_end_token_id).sum() == (left['input_ids'] == img_start_token_id).sum() == left['pixel_values'].size(0)
            if right is not None:
                assert (right['input_ids'] == img_end_token_id).sum() == (right['input_ids'] == img_start_token_id).sum() == right['pixel_values'].size(0)

            if left['pixel_values'].size(0) >= 1 and PackedDataset.check_valid(left):
                splitted_buffer.append(left)

            if right is None or right['pixel_values'].size(0) == 0:
                break

            buffer = right
            if buffer['input_ids'].size(0) <= max_tokens and PackedDataset.check_valid(buffer):
                splitted_buffer.append(buffer)
                break

        logger.debug(
            f'split a sample into {len(splitted_buffer)} samples, '
            f'current max_tokens={max_tokens}'
        )
        return splitted_buffer

    def update_buffer_list(self, buffer_list, buffer_max_len_list, buffer):
        """Split ``buffer`` if needed and file the pieces into the pending lists.

        Pieces with more than ``num_images_expected`` images are dropped with
        an error log, pieces of exactly ``max_packed_tokens`` tokens go to
        ``buffer_max_len_list``, and the rest are inserted into
        ``buffer_list``, which is kept sorted by descending image count.
        Both lists are modified in place and also returned.
        """
        # NOTE: in-place operation

        splitted_buffer = PackedDataset.split_buffer(
            buffer=buffer,
            max_tokens=self.max_packed_tokens,
            img_start_token_id=self.img_start_token_id,
            img_token_id=self.img_token_id,
            img_end_token_id=self.img_end_token_id,
        )

        for each_buffer in splitted_buffer:
            num_images_new_sample = each_buffer['pixel_values'].size(0)
            if num_images_new_sample > self.num_images_expected:
                logger.error(
                    f"Find a sample with {each_buffer['pixel_values'].size(0)} images, "
                    f'which exceeds {self.num_images_expected}'
                )
                continue

            if each_buffer['input_ids'].size(0) >= self.max_packed_tokens:
                assert each_buffer['input_ids'].size(0) == self.max_packed_tokens
                buffer_max_len_list.append(each_buffer)
                continue

            # Insert before the first buffer holding fewer images.
            find_idx = len(buffer_list)
            for buffer_idx, pending in enumerate(buffer_list):
                if pending['pixel_values'].size(0) < num_images_new_sample:
                    find_idx = buffer_idx
                    break
            buffer_list.insert(find_idx, each_buffer)

        for prev_buffer, next_buffer in zip(buffer_list, buffer_list[1:]):
            assert prev_buffer['pixel_values'].size(0) >= next_buffer['pixel_values'].size(0)

        return buffer_list, buffer_max_len_list

    def pad_buffer(self, buffer):
        """Pad ``buffer`` in place with zero images up to ``num_images_expected``.

        Padding images are all-zero tensors with ``image_flags`` 0. A buffer
        that already holds exactly ``num_images_expected`` images is returned
        untouched; a buffer without any image is padded as well.
        """
        pixel_values = buffer['pixel_values']
        if pixel_values.size(0) == self.num_images_expected:
            return buffer

        num_pad_images = self.num_images_expected - pixel_values.size(0)
        # Build the padding from the trailing shape rather than from
        # `pixel_values[0]`, which does not exist for a buffer with no image.
        pad_images = pixel_values.new_zeros((num_pad_images, *pixel_values.shape[1:]))
        pad_image_flags = torch.zeros(num_pad_images, dtype=torch.long)

        buffer['pixel_values'] = torch.cat([pixel_values, pad_images])
        buffer['image_flags'] = torch.cat([buffer['image_flags'], pad_image_flags])

        return buffer

    def postprocess_buffer(self, buffer, custom_infos=None):
        """Attach the worker state (and a deep copy of ``custom_infos``) to ``buffer``."""
        buffer['worker_state_key'] = self.worker_state_key
        buffer['worker_state_dict'] = self._state_dict
        if custom_infos is not None:
            buffer['custom_infos'] = {self.worker_state_key: copy.deepcopy(custom_infos)}
        return buffer

    def print_log(self, iter_idx, buffer_list):
        """Log progress every ``log_freq`` iterations on the designated logging worker."""
        if iter_idx % self.log_freq != 0:
            return

        if self._should_log():
            logger.info(
                f"{iter_idx=}, {len(buffer_list)=}, {self._state_dict['sample_info']}"
            )

    def __iter__(self):
        """Yield packed buffers until every dataset is exhausted.

        Each yielded buffer is a dict of tensors extended by
        ``postprocess_buffer`` with ``worker_state_key``,
        ``worker_state_dict`` and, except while draining at the very end,
        ``custom_infos`` holding the pending ``buffer_list``.
        """
        iter_idx = 0
        buffer_list = []
        buffer_max_len_list = []

        if self._should_log():
            logger.info(f'Begin to iter, {len(buffer_list)=}')

        worker_info = get_worker_info()
        worker_id = 0 if worker_info is None else worker_info.id
        num_workers = 1 if worker_info is None else worker_info.num_workers

        # Globalise the worker id / count across data-parallel ranks.
        worker_id = num_workers * self.data_rank + worker_id
        num_workers = num_workers * self.data_world_size

        rng = np.random.default_rng(seed=worker_id)

        # reset states of each dataset
        self.worker_id = worker_id
        self.worker_state_key = f'work_state_{self.worker_id}'
        self.datasets = list(self.datasets_orig)
        self.dataset_weight = list(self.dataset_weight_orig)
        self.dataset_iter_list = [iter(d) for d in self.datasets]

        for ds in self.datasets:
            ds.worker_id = worker_id
            ds.worker_state_key = f'work_state_{self.worker_id}'
            ds.num_workers = num_workers
            if self._should_log() and worker_id == 0:
                logger.info(f'set worker_id and num_workers of {ds.__class__.__name__} {ds.ds_name}')

        if self.worker_custom_infos is not None and self.worker_state_key in self.worker_custom_infos:
            custom_infos = self.worker_custom_infos[self.worker_state_key]
            # buffer list
            if 'buffer_list' in custom_infos and isinstance(custom_infos['buffer_list'], list):
                buffer_list = custom_infos['buffer_list']
                if self._should_log() and worker_id == 0:
                    logger.info(f'[{self.worker_state_key}] load buffer list --> {len(buffer_list)=}')
            # other infos

            # reset
            self.worker_custom_infos = None

        logger.debug(
            f'{self.__class__.__name__} Rank {self.data_rank} '
            f'Worker {worker_id} begin to load data'
        )

        while True:
            # Re-normalise because next_data may have dropped datasets.
            weight_total = sum(self.dataset_weight)
            self.dataset_weight = [w / weight_total for w in self.dataset_weight]
            current_dataset_idx = rng.choice(len(self.dataset_iter_list), p=self.dataset_weight)

            try:
                current_sample = self.next_data(current_dataset_idx)
            except Exception as err:
                # StopIteration is the regular end-of-data signal; any other
                # error ends the stream the same way but leaves a traceback.
                if not isinstance(err, StopIteration):
                    logger.exception('Failed to fetch the next sample, begin to empty the buffer_list')
                logger.info(f'All datasets are exhausted, begin to empty the buffer_list ({len(buffer_list)=})')
                while len(buffer_list) > 0:
                    if self.strict_mode:
                        yield self.postprocess_buffer(self.pad_buffer(buffer_list.pop(0)))
                    else:
                        yield self.postprocess_buffer(buffer_list.pop(0))
                logger.info(f'buffer_list is empty! ({len(buffer_list)=})')
                return

            buffer = self.find_buffer(buffer_list, current_sample)
            buffer = self.update_buffer(buffer, current_sample)
            buffer_list, buffer_max_len_list = self.update_buffer_list(buffer_list, buffer_max_len_list, buffer)

            # Buffers that hit the token budget are emitted immediately.
            while len(buffer_max_len_list) > 0:
                num_images_full = buffer_max_len_list[0]['pixel_values'].size(0)
                if num_images_full != self.num_images_expected:
                    logger.debug(
                        f'num tokens of a buffer exceed {self.max_packed_tokens=}, '
                        f"yield a sample with {buffer_max_len_list[0]['pixel_values'].size(0)} images"
                    )
                if self.strict_mode and num_images_full != self.num_images_expected:
                    yield self.postprocess_buffer(self.pad_buffer(buffer_max_len_list.pop(0)), {'buffer_list': buffer_list})
                else:
                    yield self.postprocess_buffer(buffer_max_len_list.pop(0), {'buffer_list': buffer_list})

            # buffer_list is sorted by descending image count, so only the
            # head needs to be inspected below.
            while len(buffer_list) > 0 and buffer_list[0]['pixel_values'].size(0) > self.num_images_expected:
                logger.error(
                    f"num images of a buffer ({buffer_list[0]['pixel_values'].size(0)}) "
                    f'is larger than num_images_expected({self.num_images_expected})'
                )
                buffer_list.pop(0)

            while len(buffer_list) > 0 and buffer_list[0]['pixel_values'].size(0) == self.num_images_expected:
                if self.debug_mode:
                    # Debug mode: repeat the first complete buffer forever.
                    debug_data = self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list})
                    while True:
                        yield debug_data.copy()

                yield self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list})

            # Too many pending buffers: emit under-filled ones from the head.
            while len(buffer_list) > self.max_buffer_size:
                logger.debug(
                    f'Failed to pack data to exactly {self.num_images_expected} images, '
                    f"yield a data sample with {buffer_list[0]['pixel_values'].size(0)} images."
                )
                if self.strict_mode:
                    yield self.postprocess_buffer(self.pad_buffer(buffer_list.pop(0)), {'buffer_list': buffer_list})
                else:
                    yield self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list})

            self.print_log(iter_idx=iter_idx, buffer_list=buffer_list)
            iter_idx += 1

    @staticmethod
    def get_cu_seqlens_and_indexes(
        data_index: torch.LongTensor,  # (seq_len,)
        input_ids: torch.LongTensor,   # (seq_len,)
        labels: torch.LongTensor,   # (seq_len,)
        len2weight: callable,
    ):
        """Derive per-sample boundaries, positions and loss weights of a packed sequence.

        Args:
            data_index: Per-token index of the packed sample each token
                belongs to; tokens of one sample must be contiguous and the
                indices consecutive.
            input_ids: Unused by this method.
            labels: Per-token labels; entries equal to ``IGNORE_TOKEN_ID`` do
                not count as effective tokens.
            len2weight: Callable mapping a sample's number of effective tokens
                to the loss weight assigned to each of its tokens.

        Returns:
            ``(cu_seqlens, indexes, loss_weight)``: cumulative sample lengths
            starting with 0 (list), per-token position inside its sample
            (list), and a float32 tensor of per-token loss weights.
        """
        indexes = []
        cu_seqlens = [0]
        loss_weight = []

        first_index = int(data_index.min())
        last_index = int(data_index.max())
        for i in range(first_index, last_index + 1):
            num_tokens = (data_index == i).sum().item()
            indexes.extend(list(range(num_tokens)))
            cu_seqlens.append(cu_seqlens[-1] + num_tokens)
            assert num_tokens > 0

            # Tokens of sample `i` must occupy one contiguous span.
            seg_start, seg_end = cu_seqlens[-2], cu_seqlens[-2] + num_tokens
            curr_data_index = data_index[seg_start:seg_end]
            assert (curr_data_index == i).all(), data_index

            curr_labels = labels[seg_start:seg_end]
            num_effective_tokens = (curr_labels != IGNORE_TOKEN_ID).sum().item()
            loss_weight.extend([len2weight(num_effective_tokens)] * num_tokens)

        assert len(indexes) == data_index.size(0), f'{len(indexes)=}, {data_index.size(0)=}'

        loss_weight = torch.tensor(loss_weight, dtype=torch.float32)
        return cu_seqlens, indexes, loss_weight


# Per-key counter of warnings already emitted; `packed_collate_fn` uses the
# 'micro_num_warning' key to stop repeating its padding warning after 5 times.
WARNING_CNT = defaultdict(int)


def packed_collate_fn(
    features,
    data_collator,
    len2weight: callable,
    max_item_length: int,
    micro_num: int = 1,
    loss_reduction_all_gather: bool = False,
    pad_id: int = 0,
):
    """Collate packed features into a batch with cumulative sequence lengths.

    Args:
        features: A packed feature dict or a list of them (as yielded by
            ``PackedDataset``). The list and its dicts are modified in place.
        data_collator: Callable invoked as ``data_collator(features=...,
            max_item_length=..., pad_id=...)`` that returns the batch dict.
        len2weight: Forwarded to ``PackedDataset.get_cu_seqlens_and_indexes``
            to compute per-token loss weights.
        max_item_length: Length every feature is padded to; a falsy value
            selects the longest ``input_ids`` among the features.
        micro_num: Required number of features. Missing ones are filled with
            copies of the first feature whose labels are all ignored.
        loss_reduction_all_gather: Stored verbatim in the batch.
        pad_id: Forwarded to ``data_collator``.

    Returns:
        The batch from ``data_collator`` with ``loss_weight`` (nested list,
        zeroed where labels are ignored), ``attention_mask`` (stacked
        cumulative sequence lengths), ``loss_reduction_all_gather`` and
        ``statistics`` (number of real samples, number of padding tokens,
        number of images whose flag is 0) set, and ``type_ids`` removed.

    Raises:
        NotImplementedError: If more than ``micro_num`` features are given.
    """
    if not isinstance(features, list):
        features = [features]

    if len(features) > micro_num:
        raise NotImplementedError(f'{len(features)=} > {micro_num=}')

    if len(features) < micro_num and WARNING_CNT['micro_num_warning'] < 5:
        # The message reports the condition that actually triggered it.
        logger.warning(
            f'{len(features)=} < {micro_num=}, '
            f'the features will be padded to satisfy micro_num requirement'
        )
        WARNING_CNT['micro_num_warning'] += 1

    # ensure that the len(features) is equal to the required micro_num
    num_features = len(features)
    while len(features) < micro_num:
        # Filler features contribute no loss: all their labels are ignored.
        features.append(copy.deepcopy(features[0]))
        features[-1]['labels'] = torch.full_like(features[-1]['labels'], IGNORE_TOKEN_ID)

    indexes = []
    cu_seqlens = []
    cu_num_images_list = [0]

    worker_state_key_list = []
    worker_state_dict_list = []
    worker_state_custom_infos_list = []

    batch_lens = [feat['input_ids'].shape for feat in features]
    max_item_length = max_item_length or max(batch_lens)[0]

    num_samples = 0
    num_padding_tokens = 0
    for feat_idx, feat in enumerate(features):
        data_index = feat.pop('data_index')
        curr_cu_seqlens, curr_indexes, curr_loss_weight = PackedDataset.get_cu_seqlens_and_indexes(
            data_index=data_index,
            input_ids=feat['input_ids'],
            labels=feat['labels'],
            len2weight=len2weight,
        )

        feat['loss_weight'] = curr_loss_weight

        # Only the original (non-filler) features count as samples.
        if feat_idx < num_features:
            num_samples += len(curr_cu_seqlens) - 1

        # The padded tail is treated as one extra segment of its own.
        if curr_cu_seqlens[-1] < max_item_length:
            curr_cu_seqlens.append(max_item_length)
            curr_indexes.extend(list(range(max_item_length - curr_cu_seqlens[-2])))

        indexes.append(torch.tensor(curr_indexes, dtype=torch.long))
        cu_seqlens.append(torch.tensor(curr_cu_seqlens, dtype=torch.int32))

        # Strip bookkeeping entries before the features reach the collator.
        worker_state_key_list.append(feat.pop('worker_state_key'))
        worker_state_dict_list.append(feat.pop('worker_state_dict'))
        worker_state_custom_infos_list.append(feat.pop('custom_infos', None))

        num_padding_tokens += (max_item_length - feat['input_ids'].size(0))
        cu_num_images_list.append(cu_num_images_list[-1] + feat['pixel_values'].size(0))

    batch = data_collator(features=features, max_item_length=max_item_length, pad_id=pad_id)
    # convert it to list in case it is converted into bf16
    batch['loss_weight'] = torch.where(batch['labels'] == IGNORE_TOKEN_ID, 0, batch['loss_weight']).tolist()
    batch['attention_mask'] = torch.stack(cu_seqlens)
    batch['loss_reduction_all_gather'] = loss_reduction_all_gather
    batch['statistics'] = torch.tensor(
        [
            num_samples,
            num_padding_tokens,
            batch['image_flags'].numel() - batch['image_flags'].sum().item(),
        ],
        dtype=torch.long,
    )
    # Tolerate a collator that already dropped `type_ids`.
    batch.pop('type_ids', None)
    return batch
