import ctypes
import importlib
import json
import logging
import os
import random
from collections import defaultdict, OrderedDict
from multiprocessing.shared_memory import SharedMemory
from typing import Optional, Callable

import numpy as np
import torch
from torch.utils.data import IterableDataset
from numpy.typing import NDArray
from transformers import PreTrainedTokenizerBase

from .indexed_dataset import PrefetchDecodeDataset
from .utils.bitset import BitSet
from .utils.vdc_sampling import van_der_corput, van_der_corput_sampling_gen

logger = logging.getLogger(__name__)
IGNORE_TGT = -100


def load_dataset_cfgs(cfg_path, cfg_json_str=None):
    """Load the list of dataset configs and fill in optional fields.

    Configs are parsed from ``cfg_json_str`` when given, otherwise read from the
    JSON file at ``cfg_path``. Each entry is validated and normalized in place:

    - ``weight`` defaults to ``1.0``;
    - ``oversize_rule`` defaults to ``"segment"`` (allowed: drop / head / segment);
    - ``transforms`` is resolved relative to the config file's directory (or the
      current working directory when ``cfg_path`` is ``None``); an empty or
      missing value becomes ``None``;
    - ``incontext_weight`` defaults to ``[1.0]``;
    - ``id`` is set to the entry's position in the list.

    Args:
        cfg_path: Path to the JSON config file; may be ``None`` when
            ``cfg_json_str`` is provided.
        cfg_json_str: Optional JSON string used instead of reading ``cfg_path``.

    Returns:
        The list of (mutated) config dicts.

    Raises:
        ValueError: If two entries share the same ``task_name``.
        AssertionError: If a required key is missing or has the wrong type.
    """
    if cfg_json_str is not None:
        cfgs = json.loads(cfg_json_str)
    else:
        with open(cfg_path, "r", encoding="utf-8") as fin:
            cfgs = json.load(fin)
    # Relative transform scripts are resolved against the config file's folder.
    # With only a JSON string there is no such folder, so use the working dir
    # instead of crashing in os.path.abspath(None).
    if cfg_path is not None:
        transform_basedir = os.path.dirname(os.path.abspath(cfg_path))
    else:
        transform_basedir = os.getcwd()

    # Best-effort load of the platform dataset map; failures are only reported.
    # path_dict is currently unused because the path remapping below is disabled.
    path_dict = None
    platform_config_path = os.getenv("PLATFORM_CONFIG_PATH")
    try:
        if platform_config_path is None:
            raise ValueError("Platform config is None")
        with open(platform_config_path, "r") as f:
            platform_cfg = json.load(f)
        path_dict = platform_cfg["dataset_map"]
        print(f"Loaded jeeves platform config from '{platform_config_path}', update dataset paths...")
    except Exception as e:
        print(f"Failing to load jeeves platform config '{platform_config_path}', error message:\n{str(e)}")

    task_name2dataset_name = dict()
    for idx, cfg in enumerate(cfgs):
        assert "dataset_name" in cfg and isinstance(cfg["dataset_name"], str)
        assert "task_name" in cfg and isinstance(cfg["task_name"], str)
        # task names must be globally unique (deliberately strict)
        if cfg["task_name"] in task_name2dataset_name:
            raise ValueError(
                f"task_name '{cfg['task_name']}' in dataset '{cfg['dataset_name']}' "
                f"has already been used in '{task_name2dataset_name[cfg['task_name']]}'."
            )
        task_name2dataset_name[cfg["task_name"]] = cfg["dataset_name"]

        assert "path" in cfg and isinstance(cfg["path"], str)
        # Disabled remapping:
        # if path_dict is not None:
        #     cfg["path"] = os.path.join(path_dict[cfg["dataset_name"]], cfg["path"])

        # dealing with optional configs
        if "weight" in cfg:
            assert isinstance(cfg["weight"], (float, int))
        else:
            cfg["weight"] = 1.0

        if "oversize_rule" in cfg:
            assert cfg["oversize_rule"] in ("drop", "head", "segment")
        else:
            cfg["oversize_rule"] = "segment"

        transforms = cfg.get("transforms")
        if transforms is None:
            cfg["transforms"] = None
        else:
            assert isinstance(transforms, str)
            # Check emptiness BEFORE joining: os.path.join(base, "") is non-empty.
            if not transforms:
                cfg["transforms"] = None
            elif os.path.isabs(transforms):
                cfg["transforms"] = transforms
            else:
                cfg["transforms"] = os.path.join(transform_basedir, transforms)

        if "incontext_weight" in cfg:
            assert isinstance(cfg["incontext_weight"], (list, tuple))
        else:
            cfg["incontext_weight"] = [1.0]
        cfg["id"] = idx
        # dataset and iterator will be built
    return cfgs


def tokenize(
    data,
    tokenizer: "PreTrainedTokenizerBase",
    max_length: int,
    add_bos: bool = False,
    add_eos: bool = True,
    repeat_data: bool = False,
    discard_leftover: bool = False,
):
    """Tokenize ``data["output"]`` and yield segments of at most ``max_length`` tokens.

    The text is tokenized, optionally wrapped with BOS/EOS, and split into
    consecutive chunks. Labels are a copy of the input ids; no positions are
    masked with ``IGNORE_TGT`` here.

    Args:
        data (dict): A dict with an ``"output"`` text. ``"input"`` is not
            tokenized (it is defaulted to ``""`` in place when missing).
            ``None`` or an empty/missing ``"output"`` yields nothing.
        tokenizer: Callable returning an object with ``input_ids``; must expose
            ``bos_token_id`` / ``eos_token_id`` when ``add_bos`` / ``add_eos``.
        max_length (int): Maximum number of tokens per yielded segment.
        add_bos (bool): Prepend ``tokenizer.bos_token_id``.
        add_eos (bool): Append ``tokenizer.eos_token_id``.
        repeat_data (bool): Unused; kept for interface compatibility.
        discard_leftover (bool): Keep only the first ``max_length`` tokens and
            drop the rest instead of emitting further segments.

    Yields:
        ``(input_ids, labels)`` pairs: equal, but independent, ``list[int]``s.
    """
    if data is None or not data.get("output"):
        return
    if "input" not in data:
        data["input"] = ""

    # Mamba's tokenizer does not add BOS/EOS by default.
    # Truncation here only bounds tokenizer work; the slicing below decides what is kept.
    limit = max_length if discard_leftover else 2**20
    token_ids = list(tokenizer(data["output"], max_length=limit, truncation=True).input_ids)
    if add_bos:
        token_ids = [tokenizer.bos_token_id] + token_ids
    if add_eos:
        token_ids = token_ids + [tokenizer.eos_token_id]
    if discard_leftover:
        # Adding BOS/EOS can push past max_length; drop that tail instead of
        # emitting a tiny leftover segment.
        token_ids = token_ids[:max_length]
    if not token_ids:
        return

    for start in range(0, len(token_ids), max_length):
        chunk = token_ids[start : start + max_length]
        # Separate list objects: consumers extend inputs and labels independently.
        yield chunk, list(chunk)


class SegmentedDataset(IterableDataset):
    """
    A wrapper around `PrefetchDecodeDataset` and a tokenizer. It is an iterable that
    loads text from the dataset, tokenizes it, and returns segments of at most
    `max_length` tokens in each iteration.

    ``__next__`` returns ``(input_ids, labels, index)`` tuples; once the underlying
    data runs out it sets ``self.exhausted`` and returns ``None`` instead of raising
    ``StopIteration``.

    A running average of tokens per sample is stored in a named shared-memory float
    (one per task name and local process index); any process attaching to the same
    name reads and writes the same value.
    """

    def __init__(
        self,
        cfg,
        tokenizer,
        num_processes: int,
        local_process_index: int,
        transform_func: Callable,
        max_length: int = 1024,
        nthreads: int = 1,
        prefetch_slice: int = 3,
        slice_size: int = 500,
        do_compact: bool = False,
        discard_leftover: bool = False,
    ):
        super().__init__()
        self.num_processes = num_processes
        self.local_process_index = local_process_index
        self.tokenizer = tokenizer
        self.do_compact = do_compact  # stored only; not used by this class
        self.cfg = cfg
        self.max_length = max_length
        self.nthreads = nthreads
        self.transform_func = transform_func
        self.prefetch_slice = prefetch_slice
        self.slice_size = slice_size
        self.abs_weight = cfg.get("abs_weight", None)
        # Relative sampling weight; MixedDataset reads it when no abs_weight is set.
        self.weight = cfg.get("weight", 1.0)
        self.task_name = cfg["task_name"]
        self.dataset_name = cfg["dataset_name"]
        self.oversize_rule = cfg["oversize_rule"]
        # NOTE(review): allow_repeat defaults to False here but to True for
        # self.allow_repeat below; kept as-is — confirm which default is intended.
        self.dataset = PrefetchDecodeDataset(
            num_processes=num_processes,
            local_process_index=local_process_index,
            path=cfg["path"],
            allow_repeat=cfg.get("allow_repeat", False),
        )
        self.exhausted = False
        self.iterator = None
        self.discard_leftover = discard_leftover

        self.counter = 0
        self.allow_repeat = cfg.get("allow_repeat", True)
        self.used = BitSet()  # indices already consumed in the current pass
        self.init_ave_tokens()

    def _shm_name(self) -> str:
        """Return the shared-memory name of this task's average-token counter."""
        return f'ave_tokens_{self.task_name.replace("/", "_")}_{self.local_process_index}'

    def init_ave_tokens(
        self,
    ):
        """Create (or attach to) the shared counter and seed it from ``cfg``.

        The first of ``avg_tokens``, ``ave_tokens`` or ``ave_tokens_per_line``
        found in ``cfg`` is used; ``-1`` marks the value as unknown.
        """
        name = self._shm_name()
        try:
            shm = SharedMemory(name=name)
        except FileNotFoundError:
            print("Create Shared Memory {}".format(name))
            shm = SharedMemory(
                create=True,
                size=ctypes.sizeof(ctypes.c_float),
                name=name,
            )

        shared_value = ctypes.c_float.from_buffer(shm.buf)
        shared_value.value = self.cfg.get(
            "avg_tokens", self.cfg.get("ave_tokens", self.cfg.get("ave_tokens_per_line", -1))
        )
        # The ctypes view pins shm.buf; release it before close(), which otherwise
        # fails because the buffer is still exported.
        del shared_value
        shm.close()
        print("Init ave_tokens for task {}: {}".format(self.task_name, self.ave_tokens))

    @property
    def ave_tokens(
        self,
    ):
        """Current average tokens per sample (negative means unknown)."""
        existing_shm = SharedMemory(name=self._shm_name())
        shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
        tmp = shared_value.value
        del shared_value
        existing_shm.close()
        return tmp

    def ave_tokens_update(self, length: int):
        """Fold ``length`` into the shared exponential moving average (alpha=0.02).

        A negative stored value (unknown) is replaced by ``length`` directly.
        """
        existing_shm = SharedMemory(name=self._shm_name())
        shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
        if shared_value.value < 0:
            shared_value.value = float(length)
        else:
            shared_value.value = 0.98 * shared_value.value + 0.02 * length
        del shared_value
        existing_shm.close()

    def _set_ave_tokens(self, value: float):
        """Overwrite the shared average-token value (used when restoring state)."""
        existing_shm = SharedMemory(name=self._shm_name())
        shared_value = ctypes.c_float.from_buffer(existing_shm.buf)
        shared_value.value = float(value)
        del shared_value
        existing_shm.close()

    def size(self):
        return self.dataset.size()

    def __iter__(self):
        # The task iterator is created lazily in __next__; building one here would
        # only produce a throwaway iterator.
        return self

    def reset(self):
        """Start a new pass: clear the exhausted flag, iterator and used-index set."""
        rank = self.local_process_index
        self.exhausted = False
        if self.iterator is not None:
            self.iterator.close()
            self.iterator = None
        self.used = BitSet()
        print("Rank {}, Reset dataset:{} done.".format(rank, self.dataset_name))

    def transform(self, data: dict) -> dict:
        """Apply ``transform_func`` with a number of in-context examples drawn
        from the normalized ``cfg["incontext_weight"]`` distribution."""
        weight = np.array(self.cfg["incontext_weight"], dtype=np.float32)
        weight = weight / weight.sum()
        num_incontext = np.random.choice(weight.shape[0], p=weight)
        # NOTE(review): an unseeded random.Random() is seeded from OS entropy, so
        # transforms are not reproducible via set_seed — confirm this is intended.
        return self.transform_func(data, num_incontext, random.Random())

    def segment_iterate(self, sample_iter=None):
        """Yield ``(input_ids, labels, index)`` for every segment of every sample.

        Args:
            sample_iter: Iterable of ``(index, data)`` pairs. When ``None`` (the
                default) a fresh ``sliced_iterate`` over unused indices is opened
                on first use.
        """
        if sample_iter is None:
            sample_iter = self.dataset.sliced_iterate(
                self.nthreads, self.prefetch_slice, self.slice_size, self.used
            )
        for index, data in sample_iter:
            data = self.transform(data)
            for input_ids, labels in tokenize(
                data,
                self.tokenizer,
                self.max_length,
                discard_leftover=self.discard_leftover,
            ):
                yield input_ids, labels, index

    def get_iterator(self):
        """Return a new generator over this task's segments."""
        # The underlying sliced iterator is opened lazily inside segment_iterate
        # (previously it was opened here and then ignored).
        return self.segment_iterate()

    def __next__(self):
        # advance the task iterator; return None (not StopIteration) when exhausted
        if self.iterator is None:
            self.iterator = self.get_iterator()
        try:
            return next(self.iterator)
        except StopIteration:
            self.exhausted = True
            return None

    def load_state_dict(self, state_dict):
        """Restore progress (used indices, exhausted flag) and the token average."""
        if state_dict.get("exhausted", False):
            self.exhausted = True
            self.used = BitSet()
        else:
            used = state_dict.get("used")
            if used is None:
                used = BitSet()
            if len(used) == len(self.dataset):
                self.exhausted = True
                self.used = BitSet()
            else:
                self.exhausted = False
                self.used = used
        saved = state_dict.get("ave_tokens", -1)
        # Restore the saved estimate verbatim instead of averaging it in; a
        # negative value is the "unknown" sentinel and keeps the current one.
        if saved is not None and saved >= 0:
            self._set_ave_tokens(saved)

    def state_dict(self):
        """Snapshot progress; a fully consumed task is saved as exhausted."""
        if len(self.used) == len(self.dataset):
            return dict(exhausted=True, used=BitSet(), ave_tokens=self.ave_tokens)
        else:
            return dict(exhausted=False, used=self.used, ave_tokens=self.ave_tokens)

    def update_state(self, indice):
        """Mark the given sample indices as consumed."""
        self.used.update(indice)


class MixedDataset(IterableDataset):
    """
    A wrapper around multiple `SegmentedDataset`, for loading data
    from different sources/tasks.

    Every ``__next__`` draws one task according to ``self.weights`` and returns one
    segment from it as a dict with keys ``task_id``, ``input``, ``target``, ``index``
    and ``is_long``. A task that runs out of data is reset when its ``allow_repeat``
    is true and closed otherwise; iteration stops once every task is closed.
    """

    def __init__(
        self,
        cfg_path: str,
        cfg_json_str,
        tokenizer,
        max_length: int,
        num_processes: int,
        process_index: int,
        local_process_index: int,
        weight_by_size: bool = True,
        nthreads: int = 5,
        prefetch_slice: int = 100,
        parallel_loading: bool = False,
        vdc_sampling: bool = False,
        update_weights_frequency: int = 1,
        seed: int = 42,
        discard_leftover: bool = False,
    ):
        super().__init__()
        # Offset by the process index so each process gets its own random stream.
        self.set_seed(seed + process_index)
        self.rank = local_process_index
        self.num_processes = num_processes
        self.weight_by_size = weight_by_size
        self.tokenizer = tokenizer
        self.eos_token_id: int = self.tokenizer.eos_token_id
        self.bos_token_id: int = self.tokenizer.bos_token_id
        # transform script path -> {"loader", "module", "last_mtime"}
        self.path2transform = dict()
        self.name_to_dataset: OrderedDict[str, SegmentedDataset] = OrderedDict()
        self.nthreads = nthreads
        self.prefetch_slice = prefetch_slice
        self.discard_leftover = discard_leftover
        # useful for indexing
        self.datasets: list[SegmentedDataset] = []
        self.names: list[str] = []
        # number of tasks not yet closed; iteration ends when it reaches 0
        self.remain = 0
        self.max_length = max_length
        self.vdc_sampling = vdc_sampling
        if self.vdc_sampling:
            self._vdc_values = [van_der_corput(i) for i in range(10**6)]  # precision of about 1e-6
            self.vdc_gen = van_der_corput_sampling_gen(self._vdc_values)

        self.update_weights_frequency = update_weights_frequency
        # samples returned so far; drives the periodic weight refresh
        self._step = 0
        self.iterators = None
        self.weights = None

        cfgs = load_dataset_cfgs(cfg_path, cfg_json_str)
        # abs_weight is optional; only report normalized absolute weights when every
        # config defines one (otherwise cfg["abs_weight"] would raise KeyError).
        if cfgs and all(cfg.get("abs_weight") is not None for cfg in cfgs):
            _sum_weight = sum(cfg["abs_weight"] for cfg in cfgs)
            _weights = {cfg["task_name"]: cfg["abs_weight"] / _sum_weight for cfg in cfgs}
            print(f"Absolute Weight of DataSet: {_weights = }")

        if parallel_loading:
            raise NotImplementedError("Parallel loading not supported.")
        # also computes the initial self.weights
        self.sequential_load(cfgs)

    def set_seed(self, seed: int):
        """Seed torch, ``random`` and NumPy global RNGs."""
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)

    def load_task(self, cfg: dict) -> "SegmentedDataset":
        """Build the `SegmentedDataset` described by one normalized config."""
        logger.info(f"Loading {cfg['path']}")
        transform_func = self.get_transform_func(cfg["task_name"], cfg["transforms"])
        return SegmentedDataset(
            cfg,
            self.tokenizer,
            num_processes=self.num_processes,
            local_process_index=self.rank,
            max_length=self.max_length,
            transform_func=transform_func,
            nthreads=self.nthreads,
            prefetch_slice=self.prefetch_slice,
            do_compact=cfg.get("do_compact", False),  # dataset level do_compact
            discard_leftover=self.discard_leftover,
        )

    def sequential_load(self, cfgs: list[dict]):
        """Load every task in config order and compute the sampling weights."""
        self.cfgs = cfgs
        for cfg in cfgs:
            seg_dataset = self.load_task(cfg)
            # dicts preserve insertion order (Python 3.7+)
            self.name_to_dataset[seg_dataset.task_name] = seg_dataset
            self.datasets.append(seg_dataset)
            self.names.append(seg_dataset.task_name)
            self.remain += 1
        self.weights = None
        self.update_weights()

    def load_state_dict(self, state_dict: dict) -> list:
        """Restore per-task state; return the task names absent from ``state_dict``."""
        # NOTE(review): tasks restored as exhausted get weight 0 but still count in
        # self.remain — confirm how fully-exhausted checkpoints should resume.
        missing_keys = []
        for name, task in self.name_to_dataset.items():
            if name in state_dict:
                task.load_state_dict(state_dict[name])
            else:
                missing_keys.append(name)
        self.update_weights()
        return missing_keys

    def save_state_dict(self, path: str):
        """Save every task's state to ``path``; BitSets are saved via ``BitSet.save``
        and only the returned file name is stored."""
        state_dict = {}
        for name, task in self.name_to_dataset.items():
            task_state = task.state_dict()
            if isinstance(task_state["used"], BitSet):
                task_state["used"] = task_state["used"].save(path)  # type: ignore
            state_dict[name] = task_state
        torch.save(state_dict, path)
        logger.info("Dataset state saved")

    def update_states(self, task_ids, indice):
        """Mark consumed sample indices per task.

        ``indice`` is either a dict ``{task_id: indices}`` or an array aligned with
        ``task_ids`` (assumed to be tensors of the same shape — TODO confirm).
        """
        is_dict = isinstance(indice, dict)
        uniq = torch.unique(task_ids)
        for idx in uniq:
            idx = idx.item()
            indexes = indice[idx] if is_dict else indice[task_ids == idx].tolist()
            self.datasets[idx].update_state(indexes)

    def get_transform_func(self, module_name: str, transform_script_path: str):
        """Load (and hot-reload on mtime change) the ``transform`` function of a script.

        Returns an identity transform when ``transform_script_path`` is ``None``.

        Raises:
            RuntimeError: If no import spec can be built for the script.
            NotImplementedError: If the script defines no ``transform``.
        """
        # `import importlib` alone does not guarantee these submodules are loaded.
        import importlib.machinery
        import importlib.util

        if transform_script_path is None:
            # allow null transform
            return lambda data, num_incontext, rand: data
        module_name = "cpm_live.transforms.{}".format(module_name)
        if transform_script_path not in self.path2transform:
            loader = importlib.machinery.SourceFileLoader(module_name, transform_script_path)
            spec = importlib.util.spec_from_loader(loader.name, loader)
            if spec is None:
                raise RuntimeError("Spec is none! {}".format(module_name))
            mod = importlib.util.module_from_spec(spec)
            self.path2transform[transform_script_path] = {
                "loader": loader,
                "module": mod,
                "last_mtime": 0,
            }
        transform_script_info = self.path2transform[transform_script_path]
        curr_mtime = float(transform_script_info["loader"].path_stats(transform_script_path)["mtime"])
        if curr_mtime > transform_script_info["last_mtime"]:
            transform_script_info["last_mtime"] = curr_mtime
            transform_script_info["loader"].exec_module(transform_script_info["module"])
        transform_func = getattr(transform_script_info["module"], "transform", None)
        if transform_func is None:
            raise NotImplementedError("Find no transform funcion in script '{}'".format(transform_script_path))
        return transform_func

    def update_weights(self):
        """Recompute the normalized task sampling distribution ``self.weights``.

        With absolute weights (decided by the first task), a task's probability is
        ``abs_weight / ave_tokens`` (``max_length`` when the average is unknown).
        Otherwise it is ``cfg["weight"]``, times the task size if ``weight_by_size``.
        Exhausted tasks get 0. If every weight is 0 the zeros are kept as-is.
        """
        if not self.datasets:
            self.weights = np.zeros(0, dtype=np.float64)
            return
        if self.datasets[0].abs_weight is not None:  # this config specifies absolute ratios
            weights = []
            for task in self.datasets:
                if task.exhausted:
                    weights.append(0.0)
                    continue
                ave_tokens = task.ave_tokens  # read once: each access opens shared memory
                if ave_tokens <= 0:  # negative = unknown; 0 would divide by zero
                    weights.append(task.abs_weight / self.max_length)
                else:
                    weights.append(task.abs_weight / ave_tokens)
            weights = np.asarray(weights, dtype=np.float64)
        else:
            # float64 so the in-place size scaling works for integer weights too
            weights = np.asarray(
                [0.0 if task.exhausted else task.cfg.get("weight", 1.0) for task in self.datasets],
                dtype=np.float64,
            )
            if self.weight_by_size:
                weights *= np.asarray([task.size() for task in self.datasets], dtype=np.float64)
        total = weights.sum()
        self.weights = weights / total if total > 0 else weights

    def __iter__(self):
        # Get fresh iterators for every task
        self.iterators = [task.get_iterator() for task in self.datasets]
        return self

    def __next__(self):
        if self.iterators is None:
            self.__iter__()
        rank = self.rank  # type: ignore
        while True:
            if self.remain == 0:
                print(f"Rank {rank}, All data exhausted!")
                raise StopIteration
            # Randomly choose one dataset
            if self.vdc_sampling:
                idx = next(self.vdc_gen)(self.weights)
            else:
                idx = np.random.choice(len(self.weights), p=self.weights)  # type: ignore
            task = self.datasets[idx]

            # A finished task generator must not end the whole mixture.
            try:
                data = next(self.iterators[idx])
            except StopIteration:
                data = None

            if data is None:
                dataset_name = task.dataset_name
                if task.allow_repeat:
                    print(f"Rank {rank}, dataset {dataset_name} exhausted, resetting it (next epoch)...")
                    task.reset()
                    # the old generator is finished; start a new pass
                    self.iterators[idx] = task.get_iterator()
                else:
                    print(f"Rank {rank}, dataset {dataset_name} exhausted, closing it.")
                    task.exhausted = True
                    self.remain -= 1
                    # Zero its weight now so it cannot be drawn (and counted) again.
                    self.update_weights()
                continue

            self._step += 1
            if self.update_weights_frequency > 0 and self._step % self.update_weights_frequency == 0:
                self.update_weights()
            return dict(
                task_id=idx,
                input=data[0],
                target=data[1],
                index=data[2],
                is_long=task.cfg.get("is_long", False),
            )


class PackedMixedDataset(IterableDataset):
    """
    A wrapper around `MixedDataset` that packs several samples into one flat
    sequence of ``max_length * packing_count`` tokens per iteration.

    Each yielded batch holds padded 1-D ``input_ids`` / ``labels`` / ``task_ids`` /
    ``position_ids`` plus ``cu_seqlens`` / ``max_seqlen`` for flash-attention's
    variable-length interface.
    """

    def __init__(
        self,
        mixed_dataset: "MixedDataset",
        packing_count: int,
        # batch_size: int,
        max_length: int,
        pose_prob: float = 0.0,
        pose_scaling_factor: float = 1.0,
        compact: bool = False,
        repeat_data: bool = False,
    ):
        '''
        Args:
            mixed_dataset (MixedDataset): Source of samples (dicts with task_id,
                input, target, index, is_long).
            packing_count (int): The number of sequences to be concatenated.
            max_length (int): The number of tokens in each sequence.
            pose_prob (float): Probability of applying PoSE position ids to a
                non-long sample.
            pose_scaling_factor (float): Target position range for PoSE, as a
                multiple of the packed length.
            compact (bool): Merge consecutive segments of the same task until
                one ends with EOS.
            repeat_data (bool): Whether to repeat data in a shuffle manner.
        '''
        super().__init__()
        self.max_length = max_length
        self.packing_count = packing_count
        self.max_total_length = max_length * packing_count
        self.repeat_data = repeat_data

        print(f"[UnpadBatchedMixedDataset] {self.max_total_length = }, {packing_count = }, {max_length = }")
        # setting compact=True concats segments originating from the same input
        # into a long sequence. the relative order of segments should be preserved
        # in mixed_dataset, e.g.,
        # - ok: task1_seg1, task2_seg1, task1_seg2, task1_seg3
        # - not_ok:  task1_seg1, task1_seg3, task2_seg1, task1_seg2
        self.compact = compact

        # tokens currently buffered across all tasks
        self.total_length = 0
        self.task2seqs = defaultdict(list)
        self.mixed_dataset = mixed_dataset
        self._max_length = max_length
        self._pose_prob = pose_prob
        self._pose_scaling_factor = pose_scaling_factor
        if self._pose_prob > 0.0:
            self._scaled_max_length = int(self.max_total_length * self._pose_scaling_factor)
        else:
            self._scaled_max_length = max_length

    def put(self, sample: dict):
        """Buffer ``sample``; with ``compact`` it may extend the previous segment
        of the same task when that segment did not end with EOS."""
        self.total_length += len(sample["target"])
        task_id = sample["task_id"]
        # Copy the token lists: upstream may hand out one list object as both
        # "input" and "target", and extending one would then extend both (and
        # mutate the caller's data).
        sample = dict(sample, input=list(sample["input"]), target=list(sample["target"]))
        if self.compact and self.task2seqs[task_id]:
            last = self.task2seqs[task_id][-1]
            if last["target"][-1] != self.mixed_dataset.eos_token_id:
                # concatenate sequential segments for longer context modeling
                last["input"].extend(sample["input"])
                last["target"].extend(sample["target"])
                return
        self.task2seqs[task_id].append(sample)

    def _pose_preprocess(
        self,
        input_ids: NDArray[np.int32],
    ) -> NDArray[np.int32]:
        """[PoSE](https://arxiv.org/abs/2309.10400v2)
        GitHub implementation: https://github.com/dwzhu-pku/PoSE/blob/master/src/train_pose.py#L156

        Returns ``min(len(input_ids), self._max_length)`` position ids: a first
        chunk starting at 0 and a second chunk shifted by a random offset.
        """
        len_chunk = min(len(input_ids), self._max_length)
        len_input = len(input_ids)
        # Chunk input randomly to fit max_length if needed
        lt1 = 0
        rt1 = random.randint(0, (len_chunk + 1) // 2)  # First chunk only contains 1/2 tokens at most
        rt2 = random.randint(lt1 + len_chunk, len_input)  # Second chunk can randomly shift when not filled max_length
        lt2 = rt2 - (len_chunk - (rt1 - lt1))  # assure all tokens are used
        chunked_input_ids = np.concatenate([input_ids[lt1:rt1], input_ids[lt2:rt2]], axis=-1)
        # Generate PoSE position ids
        position_ids = np.arange(len(chunked_input_ids), dtype=np.int32)
        len_position_ids = len(position_ids)
        lt = 0
        rt = random.randint(lt, self._scaled_max_length - len_position_ids)
        position_ids[: rt1 - lt1] += lt
        position_ids[rt1 - lt1 :] += rt
        return position_ids

    def pop(self) -> dict:
        """Flush all buffered samples into one packed batch and empty the buffer.

        Samples are laid out back to back, grouped by task; the tail is padding
        (input 0, label IGNORE_TGT, task id -1, position 0).
        """
        indexes = defaultdict(list)
        lengths: list[int] = []

        inputs = torch.zeros((self.max_total_length,), dtype=torch.int32)
        targets = torch.full((self.max_total_length,), dtype=torch.int32, fill_value=IGNORE_TGT)
        task_ids = torch.full((self.max_total_length,), dtype=torch.int32, fill_value=-1)
        position_ids = torch.zeros((self.max_total_length,), dtype=torch.int32)

        span_begin = 0
        for samples in self.task2seqs.values():
            for sample in samples:
                span_len = len(sample["input"])
                span_end = span_begin + span_len
                inputs[span_begin:span_end] = torch.tensor(sample["input"], dtype=torch.int32)
                targets[span_begin:span_end] = torch.tensor(sample["target"], dtype=torch.int32)
                task_ids[span_begin:span_end] = torch.tensor(sample["task_id"], dtype=torch.int32)

                # _pose_preprocess returns at most `_max_length` positions, so spans
                # longer than that (possible after compact merging) fall back to
                # plain positions instead of failing on a shape mismatch.
                if (
                    not sample["is_long"]
                    and self._pose_prob > 0.0
                    and random.uniform(0, 1) < self._pose_prob
                    and span_len <= self._max_length
                ):
                    span_position_ids = self._pose_preprocess(sample["input"])
                else:
                    span_position_ids = np.arange(span_len, dtype=np.int32)
                position_ids[span_begin:span_end] = torch.from_numpy(span_position_ids)
                lengths.append(len(sample["target"]))
                indexes[int(sample["task_id"])].append(sample["index"])
                self.total_length -= len(sample["target"])
                span_begin = span_end
            samples.clear()

        # sample boundaries, plus a final entry covering the padding tail
        cu_seqlens = torch.cat(
            [torch.tensor([0] + lengths).cumsum(dim=-1), torch.tensor([self.max_total_length], dtype=torch.int32)],
            dim=0,
        ).int()

        if self.repeat_data:
            # NOTE(review): only the token tensors are doubled; cu_seqlens and
            # seq_len still describe a single copy — confirm this is intended.
            inputs = inputs.repeat(2)
            targets = targets.repeat(2)
            task_ids = task_ids.repeat(2)
            position_ids = position_ids.repeat(2)

        batch = {
            "input_ids": inputs,
            "labels": targets,
            "task_ids": task_ids,
            "indexes": indexes,
            # adhere to flash attention interface
            "cu_seqlens": cu_seqlens,
            "max_seqlen": torch.max(cu_seqlens[1:] - cu_seqlens[:-1]),
            "seq_len": torch.tensor(sum(lengths)).int(),
            "task_names": self.mixed_dataset.names,
            "position_ids": position_ids,
        }
        return batch

    def will_be_full(self, sample: dict) -> bool:
        """Whether buffering ``sample`` would exceed the packed length."""
        return self.total_length + len(sample["target"]) > self.max_total_length

    def __iter__(self):
        # Emit a packed batch whenever the next sample would overflow the buffer.
        # NOTE(review): samples still buffered when mixed_dataset is exhausted are
        # not emitted.
        for sample in self.mixed_dataset:
            if self.will_be_full(sample):
                yield self.pop()
            self.put(sample)
