# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Dataloaders."""


import random
import torch
import numpy as np
from torch.utils.data import Dataset
from megatron.training import get_args
from megatron.core import mpu
from torch.utils.data._utils.collate import default_collate
from megatron_patch.tokenizer import get_tokenizer
from dataclasses import dataclass

@dataclass
class SFTCollator:
    """Collate SFT samples into fixed-length, batched tensors.

    Each incoming sample is a dict with the keys ``"input_ids"`` and
    ``"loss_mask"`` (Python lists) and ``"sample_idx"`` (an iterable of
    ``(start, end)`` pairs used as slice bounds; presumably the token
    offsets of each packed sub-sequence -- confirm against the dataset).

    ``args`` must expose ``max_padding_length`` and ``tokenizer`` must expose
    ``eos_token_id`` and ``pad_token_id``.
    """

    # Neither attribute is annotated, so ``@dataclass`` generates no fields
    # for them; callers assign both on the instance after construction.
    args = None
    tokenizer = None

    def collect_preprocess_data(self, samples):
        """Pad/truncate every sample to ``args.max_padding_length``.

        Args:
            samples: iterable of sample dicts (see the class docstring). The
                dicts are not modified.

        Returns:
            A list with one dict per sample holding NumPy arrays:
            ``input_ids`` (int64, length ``max_len + 1``: the tokens padded
            with EOS and truncated to ``max_len``, followed by one pad
            token), ``loss_mask`` (float64, length ``max_len``),
            ``attention_mask`` (int64, ``max_len x max_len``, lower
            triangular within each segment, zero elsewhere) and
            ``position_ids`` (int64, length ``max_len``, restarting at 0 at
            the start of every segment).
        """
        max_len = self.args.max_padding_length
        eos_id = self.tokenizer.eos_token_id
        pad_id = self.tokenizer.pad_token_id

        sample_list = []
        for item in samples:
            # Allocate as int64 up front. Re-labelling a float64 buffer via
            # ``arr.dtype = "int64"`` reinterprets the raw bytes, which turns
            # every 1.0 into 4607182418800017408 instead of 1.
            position_ids = np.arange(max_len, dtype=np.int64)
            attention_mask = np.zeros((max_len, max_len), dtype=np.int64)

            for sidx in item["sample_idx"]:
                start = sidx[0]
                # Clip the segment to the padded length so that the source
                # and destination shapes always agree, even for a segment
                # that starts after position 0 and runs past ``max_len``.
                end = min(sidx[1], max_len)
                seg_len = end - start
                if seg_len <= 0:
                    # Segment lies entirely beyond the truncation point.
                    continue
                attention_mask[start:end, start:end] = np.tril(
                    np.ones((seg_len, seg_len), dtype=np.int64)
                )
                position_ids[start:end] = np.arange(seg_len, dtype=np.int64)

            # Pad into fresh lists so the caller's sample dict stays intact
            # (a dataset that caches its items would otherwise see them grow
            # by ``max_len`` entries on every pass).
            input_ids = item["input_ids"] + [eos_id] * max_len
            loss_mask = item["loss_mask"] + [0] * max_len

            # Guard against a mask that is all zero after truncation, which
            # would make the loss NaN.
            if sum(loss_mask[: max_len - 1]) == 0:
                loss_mask[max_len - 2] = 1

            sample_list.append(
                {
                    "input_ids": np.array(
                        input_ids[:max_len] + [pad_id], dtype=np.int64
                    ),
                    "loss_mask": np.array(loss_mask[:max_len], dtype=np.float64),
                    "attention_mask": attention_mask,
                    "position_ids": position_ids,
                }
            )
        return sample_list

    def __call__(self, samples):
        """Preprocess ``samples`` and stack them with ``default_collate``."""
        sample_list = self.collect_preprocess_data(samples)
        return default_collate(sample_list)
def build_pretraining_data_loader(dataset, consumed_samples):
    """Build a dataloader for ``dataset``, resuming after ``consumed_samples``.

    The sampler is selected by ``args.dataloader_type``:

    * ``'single'``   -> ``MegatronPretrainingSampler``
    * ``'cyclic'``   -> ``MegatronPretrainingRandomSampler``
    * ``'external'`` -> ``dataset`` is returned untouched (the user supplies
      a torch-compatible dataloader and any samplers themselves)

    Returns ``None`` when ``dataset`` is ``None``. Raises ``Exception`` for
    any other dataloader type.
    """
    if dataset is None:
        return None

    args = get_args()
    loader_type = args.dataloader_type

    # External dataloaders are passed straight through.
    if loader_type == "external":
        return dataset

    # Pick the Megatron batch sampler.
    if loader_type == 'single':
        batch_sampler = MegatronPretrainingSampler(
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=args.micro_batch_size,
            data_parallel_rank=mpu.get_data_parallel_rank(),
            data_parallel_size=mpu.get_data_parallel_world_size(),
        )
    elif loader_type == 'cyclic':
        batch_sampler = MegatronPretrainingRandomSampler(
            dataset,
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=args.micro_batch_size,
            data_parallel_rank=mpu.get_data_parallel_rank(),
            data_parallel_size=mpu.get_data_parallel_world_size(),
            data_sharding=args.data_sharding,
        )
    else:
        raise Exception('{} dataloader type is not supported.'.format(
                args.dataloader_type))

    # Options shared by both torch DataLoader flavours below.
    loader_kwargs = {
        "batch_sampler": batch_sampler,
        "num_workers": args.num_workers,
        "pin_memory": True,
        # Workers can only be kept alive when there are workers at all.
        "persistent_workers": bool(args.num_workers > 0),
    }

    # Fine-tuning on a "-Raw" dataset needs the SFT collator; every other
    # configuration falls back to torch's default collation.
    if args.train_mode == 'finetune' and ("-Raw" in args.dataset):
        sft_collator = SFTCollator()
        sft_collator.args = args
        sft_collator.tokenizer = get_tokenizer()
        loader_kwargs["collate_fn"] = sft_collator

    return torch.utils.data.DataLoader(dataset, **loader_kwargs)

class MegatronPretrainingSampler:
    """Sequential batch sampler that hands each data-parallel rank its slice.

    Sample indices are walked in order from ``consumed_samples`` up to
    ``total_samples``. Every time ``micro_batch_size * data_parallel_size``
    indices have been gathered, this rank's contiguous ``micro_batch_size``
    slice of that group is yielded.
    """

    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size, drop_last=True):
        """Store the sampling parameters and validate them.

        Raises:
            AssertionError: if there are no samples, nothing is left to
                consume, a size is not positive, or the rank is not smaller
                than the data-parallel size.
        """
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        # Number of samples consumed across all ranks per step.
        self.micro_batch_times_data_parallel_size = (
            self.micro_batch_size * data_parallel_size
        )
        self.drop_last = drop_last

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.consumed_samples < self.total_samples, \
            'no samples left to consume: {}, {}'.format(self.consumed_samples,
                                                        self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        """Return the total number of samples (not the number of batches)."""
        return self.total_samples

    def get_start_end_idx(self):
        """Return the ``(start, end)`` bounds of this rank's slice of a group."""
        first = self.data_parallel_rank * self.micro_batch_size
        return first, first + self.micro_batch_size

    def __iter__(self):
        """Yield this rank's list of sample indices for each full group.

        A trailing, incomplete group is only yielded when ``drop_last`` is
        false; the slice taken from it may be shorter than
        ``micro_batch_size`` (or empty) for some ranks.
        """
        pending = []
        for sample_idx in range(self.consumed_samples, self.total_samples):
            pending.append(sample_idx)
            if len(pending) != self.micro_batch_times_data_parallel_size:
                continue
            lo, hi = self.get_start_end_idx()
            yield pending[lo:hi]
            pending = []

        # Leftover samples that did not fill a whole group.
        if not self.drop_last and len(pending) > 0:
            lo, hi = self.get_start_end_idx()
            yield pending[lo:hi]


class RandomSeedDataset(Dataset):
    """Dataset wrapper that reseeds the global RNGs before every item fetch.

    The seed used for item ``idx`` is ``idx + curr_seed``, where
    ``curr_seed`` starts at ``args.seed`` and is shifted by ``set_epoch``.
    """

    def __init__(self, dataset):
        """Wrap ``dataset`` and read the base seed from the global args."""
        args = get_args()
        self.base_seed = args.seed
        self.curr_seed = args.seed
        self.dataset = dataset

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    def set_epoch(self, epoch):
        """Set the current seed to ``base_seed + epoch``."""
        self.curr_seed = self.base_seed + epoch

    def __getitem__(self, idx):
        """Seed torch, ``random`` and NumPy, then return ``dataset[idx]``."""
        item_seed = idx + self.curr_seed
        # Same order as always: torch first, then the stdlib, then NumPy.
        for reseed in (torch.manual_seed, random.seed, np.random.seed):
            reseed(item_seed)
        return self.dataset[idx]


class MegatronPretrainingRandomSampler:
    """Shuffling batch sampler whose permutation is seeded by the epoch.

    The epoch and the position inside it are derived from
    ``consumed_samples`` each time iteration starts. With ``data_sharding``
    every rank shuffles inside its own contiguous bucket of the dataset;
    without it, one global permutation is strided across the ranks.
    """

    def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size, data_sharding):
        """Store the sampling parameters and validate them.

        Raises:
            AssertionError: if there are no samples, a size is not positive,
                or the rank is not smaller than the data-parallel size.
        """
        self.dataset = dataset
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.data_sharding = data_sharding
        # Number of samples consumed across all ranks per step.
        self.micro_batch_times_data_parallel_size = (
            self.micro_batch_size * data_parallel_size
        )
        # Samples left over once the dataset is cut into full global batches.
        self.last_batch_size = (
            self.total_samples % self.micro_batch_times_data_parallel_size
        )

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        """Return the total number of samples (not the number of batches)."""
        return self.total_samples

    def __iter__(self):
        """Yield lists of ``micro_batch_size`` shuffled indices for this rank.

        ``consumed_samples`` is advanced by one global batch for every
        micro-batch yielded. A trailing incomplete micro-batch is dropped.
        """
        usable_samples = self.total_samples - self.last_batch_size
        self.epoch = self.consumed_samples // usable_samples
        consumed_this_epoch = self.consumed_samples % usable_samples
        assert consumed_this_epoch % self.micro_batch_times_data_parallel_size == 0

        if isinstance(self.dataset, RandomSeedDataset):
            self.dataset.set_epoch(self.epoch)

        # The permutation depends only on the epoch, so every rank draws the
        # same one and a resumed run reproduces it.
        rng = torch.Generator()
        rng.manual_seed(self.epoch)

        if self.data_sharding:
            # Each rank owns one contiguous bucket and shuffles inside it.
            bucket_size = (
                self.total_samples // self.micro_batch_times_data_parallel_size
            ) * self.micro_batch_size
            bucket_offset = consumed_this_epoch // self.data_parallel_size
            bucket_start = self.data_parallel_rank * bucket_size

            shuffled = torch.randperm(bucket_size, generator=rng).tolist()
            idx_range = [bucket_start + pos for pos in shuffled[bucket_offset:]]
        else:
            # One global permutation; skip what was already consumed and
            # give every rank each ``data_parallel_size``-th index.
            full_bucket_size = (
                self.total_samples // self.micro_batch_size
            ) * self.micro_batch_size
            shuffled = torch.randperm(full_bucket_size, generator=rng).tolist()
            remaining = shuffled[consumed_this_epoch:]
            idx_range = remaining[self.data_parallel_rank::self.data_parallel_size]

        micro_batch = []
        for idx in idx_range:
            micro_batch.append(idx)
            if len(micro_batch) != self.micro_batch_size:
                continue
            self.consumed_samples += self.micro_batch_times_data_parallel_size
            yield micro_batch
            micro_batch = []
