              
 
                                                      
                                                  

from dataclasses import dataclass
from datetime import datetime
from types import SimpleNamespace
from typing import Dict, Sequence
import copy
import itertools

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from torch.utils.data import IterableDataset as TorchIterableDataset
import datasets
import torch
import transformers

from megatron_datasets.utils import print_rank_0, print_datetime
from megatron_datasets.mega_indexed_jsonl_dataset import MegaIndexedJsonlDataset
from megatron_datasets.mega_indexed_jsonl_dataset import get_epoch_and_line, update_epoch_and_line
from megatron_datasets.indexed_jsonl_pretrain_dataset import PretrainDataset


def tokenize_text(tokenizer, text):
    """Tokenize *text* wrapped in BOS/EOS and build next-token targets.

    The tokenizer is called once on ``bos_token + text + eos_token`` with
    ``add_special_tokens=False``.

    Returns:
        A ``(input_ids, labels, attention_mask)`` tuple where ``input_ids``
        is the token sequence without its last element, ``labels`` is the
        same sequence without its first element, and ``attention_mask`` is
        returned exactly as the tokenizer produced it (it is not shortened,
        so it is one element longer than ``input_ids``).
    """
    wrapped = f'{tokenizer.bos_token}{text}{tokenizer.eos_token}'
    encoding = tokenizer(wrapped, add_special_tokens=False)

    token_ids = encoding.input_ids
    # Targets are cut from an independent copy of the ids, shifted left by one.
    shifted_targets = copy.deepcopy(token_ids)[1:]

    return token_ids[:-1], shifted_targets, encoding.attention_mask


@dataclass
class PackedPretrainDataset(PretrainDataset):
    """Pretraining dataset that packs tokenized documents into fixed-length samples.

    Documents read from ``self.underlying`` are tokenized, cut into segments
    of at most ``max_position_embeddings`` tokens and appended to a
    per-domain buffer. Each time a buffer holds exactly ``max_seq_len``
    tokens it is emitted as one sample, together with the cumulative segment
    boundaries (``cu_seqlens``).

    NOTE(review): ``underlying``, ``start_epoch``, ``in_iter``,
    ``eval_samples`` and the stored constructor arguments are read by the
    methods below; they are presumably initialised by
    ``PretrainDataset.__init__`` -- confirm against the base class.
    """

    def __init__(
        self,
        tokenizer,
        max_seq_len,
        max_position_embeddings,
        path_likes,
        domain_probabilities,
        domain_names,
        train_data_consuming_progresses=None,
        train=False,
        rank=0,
        dp_rank=0,
        dp_size=1,
        access_policy_interleave=False,
        shuffle_buffer_size=1000,
        eval_samples=None,
        seed=0,
        retention_rates_per_domains=None,
        enable_pareto=None,
        pareto_alphas=None,
        pareto_scales=None,
        pareto_score_scales=None,
    ):
        """Initialise the base dataset and remember the per-segment length cap.

        Args:
            max_position_embeddings: Stored on the instance; upper bound on
                the number of tokens in one packed segment.
            retention_rates_per_domains, enable_pareto, pareto_alphas,
                pareto_scales, pareto_score_scales: Forwarded to the base
                class. ``None`` (the default) is replaced by a fresh empty
                list, so instances never share one default list object.
            All remaining arguments are forwarded unchanged to
            ``PretrainDataset.__init__``.
        """

        def _list_or_empty(value):
            # A new list per call; a literal ``[]`` default would be one
            # object shared by every instance constructed with defaults.
            return [] if value is None else value

        super().__init__(
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            path_likes=path_likes,
            domain_probabilities=domain_probabilities,
            domain_names=domain_names,
            train_data_consuming_progresses=train_data_consuming_progresses,
            train=train,
            rank=rank,
            dp_rank=dp_rank,
            dp_size=dp_size,
            access_policy_interleave=access_policy_interleave,
            shuffle_buffer_size=shuffle_buffer_size,
            eval_samples=eval_samples,
            seed=seed,
            retention_rates_per_domains=_list_or_empty(retention_rates_per_domains),
            enable_pareto=_list_or_empty(enable_pareto),
            pareto_alphas=_list_or_empty(pareto_alphas),
            pareto_scales=_list_or_empty(pareto_scales),
            pareto_score_scales=_list_or_empty(pareto_score_scales),
        )
        self.max_position_embeddings = max_position_embeddings

    def iter_in_epoch(self, epoch, bufs):
        """Yield packed samples for one pass over ``self.underlying``.

        Args:
            epoch: Epoch number, reported in the log line and in every sample.
            bufs: One mutable buffer per domain with the attributes
                ``input_ids``, ``labels``, ``cu_seqlens`` and ``n_packed``.
                Buffers are mutated in place; partially filled buffers are
                left as they are when the pass ends.

        Yields:
            dict with keys ``input_ids`` and ``labels`` (int64 tensors of
            ``max_seq_len`` tokens), ``train``, ``epoch``, ``line`` (number of
            examples of that domain read since its counter was last reset)
            and ``cu_seqlens`` (int64 tensor of cumulative segment ends,
            starting at 0).
        """
        if torch.distributed.get_rank() == 0:
            print(f'PretrainDataset.iter_in_epoch epoch {epoch} train {self.train}')

        # Restart the per-domain example counters at the beginning of the pass.
        for domain_id, _ in enumerate(self.domain_names):
            bufs[domain_id].n_packed = 0

        for example in self.underlying:
            domain_id = example['domain_id']
            buf = bufs[domain_id]
            # Counted before the 'deleted' check, so skipped examples still
            # advance the reported 'line'.
            buf.n_packed += 1
            if example.get('deleted', False):
                continue

            text = example['content']
            _input_ids, _labels, _ = tokenize_text(self.tokenizer, text)

            while len(_input_ids):
                # Take as much as fits: limited by what is left of the
                # document, by the per-segment cap, and by the free space in
                # the domain buffer.
                new_len = min(
                    len(_input_ids),
                    self.max_position_embeddings,
                    self.max_seq_len - len(buf.input_ids),
                )

                # Append the segment and record where it ends.
                buf.input_ids += _input_ids[:new_len]
                buf.labels += _labels[:new_len]
                buf.cu_seqlens.append(buf.cu_seqlens[-1] + new_len)

                # Keep the unconsumed tail of the document for the next round.
                _input_ids = _input_ids[new_len:]
                _labels = _labels[new_len:]

                assert len(buf.input_ids) <= self.max_seq_len
                if len(buf.input_ids) == self.max_seq_len:
                    yield {
                        'input_ids': torch.tensor(buf.input_ids, dtype=torch.int64),
                        'labels': torch.tensor(buf.labels, dtype=torch.int64),
                        'train': self.train,
                        'epoch': epoch,
                        'line': buf.n_packed,
                        'cu_seqlens': torch.tensor(buf.cu_seqlens, dtype=torch.int64)
                    }
                    # Start a new, empty sample for this domain.
                    buf.input_ids = []
                    buf.labels = []
                    buf.cu_seqlens = [0]
                    buf.n_packed = 0

                    if not self.train:
                        self.eval_yielded += 1
                        # After eval_samples samples the counter is reset and
                        # this pass ends; the caller then starts the next one.
                        if self.eval_yielded >= self.eval_samples:
                            self.eval_yielded = 0
                            return

    def __iter__(self):
        """Iterate over packed samples, epoch after epoch, without end.

        Starts at ``self.start_epoch``. After every pass ``self.underlying``
        is rebuilt as a ``MegaIndexedJsonlDataset`` for the following epoch
        with ``consumed=0``. The per-domain buffers are created once and
        shared by all passes, so a partially filled sample carries over from
        one pass to the next.

        Raises:
            AssertionError: if iteration was already started on this instance.
        """
        assert not self.in_iter
        self.in_iter = True
        self.eval_yielded = 0
        bufs = [
            SimpleNamespace(input_ids=[], labels=[], cu_seqlens=[0], n_packed=0)
            for domain_id, _ in enumerate(self.domain_names)
        ]
        for epoch in itertools.count(start=self.start_epoch):
            yield from self.iter_in_epoch(epoch, bufs)

            self.underlying = MegaIndexedJsonlDataset(
                self.path_likes,
                self.domain_probabilities,
                self.domain_names,
                dp_rank=self.dp_rank,
                dp_size=self.dp_size,
                epoch=epoch + 1,
                consumed=0,
                shuffle_buffer_size=self.shuffle_buffer_size,
                seed=self.seed,
                train=self.train,
                retention_rates_per_domains=self.retention_rates_per_domains,
                enable_pareto=self.enable_pareto,
                pareto_alphas=self.pareto_alphas,
                pareto_scales=self.pareto_scales,
                pareto_score_scales=self.pareto_score_scales,
            )
        assert False, 'never reachable'


def build_train_valid_test_datasets(args, tokenizer, rank=0, dp_rank=0, dp_size=1):
    """Build the packed training dataset and, if configured, the eval dataset.

    Args:
        args: Parsed run configuration; only the attributes read below are used.
        tokenizer: Tokenizer handed to every dataset.
        rank, dp_rank, dp_size: Passed through to every dataset; ``dp_size``
            also divides the number of evaluation samples.

    Returns:
        ``(train_ds, eval_ds, test_ds)``. ``eval_ds`` is ``None`` when
        ``args.px_eval_data_path`` is ``None``; ``test_ds`` is always ``None``.

    Raises:
        AssertionError: if ``args.num_workers`` is greater than 1.
    """
    train_paths = args.data_path
    eval_paths = args.px_eval_data_path
    # Options that only the training dataset receives, read up front.
    train_options = {
        'domain_probabilities': args.px_domain_probabilities,
        'retention_rates_per_domains': args.px_retention_rates_per_domain,
        'enable_pareto': args.px_train_apply_pareto,
        'pareto_alphas': args.px_train_pareto_alpha,
        'pareto_scales': args.px_train_pareto_scale,
        'pareto_score_scales': args.train_pareto_score_scale,
        'domain_names': args.px_train_data_domain_names,
    }
    assert args.num_workers <= 1

    print_rank_0(
        f'build_train_valid_datasets train_data_consuming_progresses {args.train_data_consuming_progresses}'
    )

    # Process placement is identical for the training and the eval dataset.
    placement = {'rank': rank, 'dp_rank': dp_rank, 'dp_size': dp_size}

    train_ds = PackedPretrainDataset(
        tokenizer,
        args.seq_length,
        args.max_position_embeddings,
        train_paths,
        train_data_consuming_progresses=args.train_data_consuming_progresses,
        train=True,
        access_policy_interleave=args.px_indexed_jsonl_dataset_access_policy_interleave,
        shuffle_buffer_size=args.px_shuffle_buffer_size,
        seed=args.seed,
        **placement,
        **train_options,
    )

    if eval_paths is None:
        eval_ds = None
    else:
        # Samples each data-parallel rank yields before an eval pass ends.
        per_rank_eval_samples = args.eval_iters * args.global_batch_size // dp_size
        eval_ds = PackedPretrainDataset(
            tokenizer,
            args.seq_length,
            args.max_position_embeddings,
            eval_paths,
            None,
            args.px_eval_data_domain_names,
            train_data_consuming_progresses=None,
            train=False,
            access_policy_interleave=args.px_indexed_jsonl_dataset_access_policy_interleave,
            shuffle_buffer_size=0,
            eval_samples=per_rank_eval_samples,
            seed=0,
            **placement,
        )

    return train_ds, eval_ds, None
