# Source: https://github.com/princeton-nlp/ProLong/blob/main/training/dataset.py

import os
import torch

from streaming import StreamingDataset, Stream
import logging

from itertools import islice

from typing import Dict, Any, List, Tuple
from collections.abc import Iterator


class SafeStream(Stream):
    """``Stream`` that is safe if multiple processes try to decompress the same shard.

    Each process asks the parent implementation to decompress into its own
    temporary path (the final path plus ``.<SLURM_JOB_ID or "local">-<pid>``)
    and then moves that file onto the final shard path.
    """

    def _decompress_shard_part(self, zip_info, zip_filename, raw_filename, compression):
        unique_extension = "." + str(os.getenv("SLURM_JOB_ID", "local")) + "-" + str(os.getpid())
        tmp_filename = raw_filename + unique_extension
        try:
            super()._decompress_shard_part(zip_info, zip_filename, tmp_filename, compression)
        except BaseException:
            # Don't leave this process's partially written temp file behind;
            # removal is best effort so the original error is what propagates.
            try:
                os.remove(tmp_filename)
            except OSError:
                pass
            raise
        # os.replace overwrites an existing destination on every platform
        # (os.rename raises FileExistsError on Windows if another process
        # already moved its copy into place), and is atomic on POSIX.
        os.replace(tmp_filename, raw_filename)


class SortByLengthDataset(StreamingDataset):
    """``StreamingDataset`` that reorders samples within consecutive local windows.

    Samples are pulled from the parent iterator in windows of
    ``sort_by_length_size`` items. Each window is emitted in descending order of
    cost, where cost is the sum of squared segment lengths (see
    ``_negative_item_cost``). Order across windows is unchanged. A window size of
    1 or less disables sorting entirely.
    """

    def __init__(
        self,
        *args,
        sort_by_length_size=1,
        single_seq: bool = False,
        per_device_max_tokens: int = 4294967296,
        apply_instruct_masks: bool = False,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        # Number of consecutive samples sorted together; <= 1 means stream order.
        self.sort_by_length_size = sort_by_length_size
        # The options below are only stored here; nothing in this class reads them.
        self.single_seq = single_seq
        self.per_device_max_tokens = per_device_max_tokens
        self.apply_instruct_masks = apply_instruct_masks

    def _negative_item_cost(self, item):
        """Return minus the summed squared segment lengths of ``item`` (a sort key).

        Segment lengths come from ``item["indices"]`` (``(start, end)`` pairs) when
        present, else from ``item["length"]``, else from ``len(item["input_ids"])``.
        Negating makes an ascending sort put the most expensive samples first.
        """
        if "indices" in item:
            squared_lengths = ((stop - begin) ** 2 for begin, stop in item["indices"])
            return -sum(squared_lengths)
        if "length" in item:
            return -item["length"] ** 2
        return -len(item["input_ids"]) ** 2

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        window = self.sort_by_length_size
        if window <= 1:
            yield from super().__iter__()
            return

        source = super().__iter__()
        chunk = list(islice(source, window))
        while chunk:
            # list.sort is stable: equal-cost samples keep their stream order.
            chunk.sort(key=self._negative_item_cost)
            yield from chunk
            chunk = list(islice(source, window))

class DataCollator:
    """Pack the token segments of a list of features into one flat batch.

    Each feature is a mapping with an ``"input_ids"`` token sequence and,
    optionally, ``"indices"`` (a list of ``(start, end)`` segment bounds) and
    ``"mask"`` (per-token values; tokens whose mask is 0 get label -100 when
    ``apply_instruct_masks`` is enabled). Kept segments are concatenated into a
    single tensor, and their lengths are reported in ``seq_lengths``.

    Args:
        tokenizer: Stored on the instance; not used by ``__call__``.
        single_seg: Ignore ``"indices"`` and treat each feature as one segment.
        per_device_max_tokens: Token budget for the whole batch. A segment that
            crosses the budget is truncated to fit; segments left with fewer than
            2 tokens are skipped.
        apply_instruct_masks: Apply a feature's ``"mask"`` (if it has one) to its labels.
        single_seq: Keyword alias for ``single_seg`` (the spelling
            ``SortByLengthDataset`` uses).
    """

    def __init__(
        self,
        tokenizer,
        single_seg: bool = False,
        per_device_max_tokens: int = 4294967296,
        apply_instruct_masks: bool = False,
        single_seq: bool = False,
    ):
        self.tokenizer = tokenizer
        # One flag, exposed under both spellings so either attribute name works.
        self.single_seg = bool(single_seg or single_seq)
        self.single_seq = self.single_seg
        self.per_device_max_tokens = per_device_max_tokens
        self.apply_instruct_masks = apply_instruct_masks

    @torch.no_grad()
    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Collate ``features`` into a packed batch.

        Returns:
            dict with ``input_ids`` and ``labels`` (concatenation of all kept
            segments; empty long tensors if nothing was kept),
            ``attention_mask=None`` and ``seq_lengths`` (one long entry per kept
            segment, in order).
        """
        input_ids: List[torch.Tensor] = []
        labels: List[torch.Tensor] = []
        seq_lengths: List[int] = []

        available_tokens = self.per_device_max_tokens
        for item in features:
            if available_tokens <= 0:
                break  # budget exhausted: no later segment can be emitted

            tokens = torch.tensor(item["input_ids"], dtype=torch.long)
            if self.single_seg or "indices" not in item:
                indices = [(0, len(item["input_ids"]))]
            else:
                indices = item["indices"]
            use_mask = self.apply_instruct_masks and "mask" in item

            for a, b in indices:
                if available_tokens <= 0:
                    break
                # Truncate the segment so it fits in the remaining budget.
                b = a + min(b - a, available_tokens)
                if b - a <= 1:
                    # A 0/1-token segment has no trainable label (the first is masked).
                    continue

                input_ids.append(tokens[a:b])

                # clone() so label edits never write into the shared token buffer.
                seg_labels = tokens[a:b].clone()
                seg_labels[0] = -100  # Don't predict the first token of a segment
                if use_mask:
                    # Tokens whose mask is 0 are excluded from the loss.
                    mask = torch.tensor(item["mask"][a:b], dtype=torch.long)
                    seg_labels[mask == 0] = -100
                labels.append(seg_labels)

                seq_lengths.append(b - a)
                available_tokens -= b - a

        # torch.cat rejects an empty list, so fall back to empty long tensors.
        return dict(
            input_ids=torch.cat(input_ids, dim=0) if input_ids else torch.zeros(0, dtype=torch.long),
            attention_mask=None,
            labels=torch.cat(labels, dim=0) if labels else torch.zeros(0, dtype=torch.long),
            seq_lengths=torch.tensor(seq_lengths, dtype=torch.long),
        )
