"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import os
import random
import itertools
from typing import Dict, Optional, Sequence
import copy

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset, load_from_disk
from torch.nn.utils.rnn import pad_sequence


def extract_answer(text):
    """Return the final-answer part of a GSM8k-style solution string.

    If ``text`` contains the ``####`` marker, the result is the marker plus
    everything after its first occurrence (e.g. ``"#### 1000"``). Otherwise
    it is the whole stripped text. Commas are removed in both cases.
    """
    marker = '####'
    cleaned = text.strip()
    _, found, tail = cleaned.partition(marker)
    if not found:
        return cleaned.replace(',', '')
    return (marker + tail).strip().replace(',', '')

def extract_cot(text):
    """Return the stripped reasoning that precedes the first ``####`` marker.

    Returns ``None`` when ``text`` has no ``####`` marker at all.
    """
    reasoning, marker, _ = text.strip().partition('####')
    return reasoning.strip() if marker else None

class GSM8kDataset(Dataset):
    """Tokenized ``question||answer`` pairs for causal-LM fine-tuning.

    Each non-blank line of the input file must contain exactly one ``||``
    separator; any other line is skipped. An example is tokenized as
    ``' {question} ' + answer``. Its labels are a copy of those ids with the
    first ``len(tokenize(question))`` positions set to -100, so only the
    answer part contributes to the loss.
    """

    def __init__(self, tokenizer, file_path, max_length):
        """
        Args:
            tokenizer: callable accepting a list of strings plus
                ``add_special_tokens``, ``truncation`` and ``max_length``
                keyword arguments; it must return a mapping whose
                ``"input_ids"`` entry is a list of id lists.
            file_path: UTF-8 text file of ``question||answer`` lines.
            max_length: truncation length applied to both encodings.

        Raises:
            AssertionError: if ``file_path`` is not an existing file.
            ValueError: if the file has no usable ``question||answer`` line.
        """
        assert os.path.isfile(file_path), f"Input file path {file_path} not found"
        print(f'Creating features from dataset file at {file_path}')

        pairs = []
        with open(file_path, encoding="utf-8") as f:
            for line in f.read().splitlines():
                if not line or line.isspace():
                    continue
                parts = line.split('||')
                if len(parts) == 2:
                    pairs.append(parts)
        if not pairs:
            # Without this check the unpacking below fails with an opaque message.
            raise ValueError(f"No 'source||target' lines found in {file_path}")

        src_lines = [src for src, _ in pairs]
        tgt_lines = [tgt for _, tgt in pairs]
        edited_sents_all = [' {} '.format(src) + tgt for src, tgt in zip(src_lines, tgt_lines)]

        batch_encoding_all = tokenizer(edited_sents_all, add_special_tokens=True, truncation=True, max_length=max_length)
        batch_encoding_src = tokenizer(src_lines, add_special_tokens=True, truncation=True, max_length=max_length)

        self.examples_all = batch_encoding_all["input_ids"]
        self.input_all = batch_encoding_src["input_ids"]
        self.labels_all = copy.deepcopy(self.examples_all)

        # Kept for backward compatibility with external readers; never populated.
        self.src_sent_cot = []
        self.tgt_sent_cot = []

        for labels, src_ids in zip(self.labels_all, self.input_all):
            # Clamp to the label length. Slice-assigning a longer list would
            # otherwise grow ``labels`` past the matching ``input_ids``.
            # NOTE(review): the question is tokenized on its own here, but
            # inside the example it is wrapped as ' {src} '. Some tokenizers
            # may split those prefixes into different token counts, making
            # the mask boundary off by a token. Confirm for the tokenizer in use.
            n_masked = min(len(src_ids), len(labels))
            labels[:n_masked] = [-100] * n_masked

    def __len__(self):
        return len(self.examples_all)

    def __getitem__(self, i):
        """Return ``(question_ids, example_ids, example_labels)`` as LongTensors."""
        return (
            torch.tensor(self.input_all[i], dtype=torch.long),
            torch.tensor(self.examples_all[i], dtype=torch.long),
            torch.tensor(self.labels_all[i], dtype=torch.long),
        )

class GSM8kDataCollator:
    """Collate ``GSM8kDataset`` items into a causal-LM training batch.

    Each example is the ``(question_ids, example_ids, example_labels)`` triple
    returned by ``GSM8kDataset.__getitem__``; only the last two are used.
    Ragged sequences are right-padded with -100. In ``input_ids`` every
    negative entry is then replaced by ``tokenizer.eos_token_id``, while
    ``labels`` keep -100 at padded positions.
    """
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, examples):
        _, full_ids, full_labels = zip(*examples)

        batch_ids = self._tensorize_batch(full_ids)
        pad_positions = batch_ids < 0
        batch_ids[pad_positions] = self.tokenizer.eos_token_id

        batch_labels = self._tensorize_batch(full_labels)
        return {"input_ids": batch_ids, "labels": batch_labels}

    def _tensorize_batch(self, examples):
        """Stack equal-length sequences; otherwise right-pad them with -100.

        Accepts a sequence of int lists/tuples or of 1-D tensors.
        """
        if isinstance(examples[0], (list, tuple)):
            examples = [torch.tensor(seq, dtype=torch.long) for seq in examples]
        first_len = examples[0].size(0)
        if any(seq.size(0) != first_len for seq in examples):
            return pad_sequence(examples, batch_first=True, padding_value=-100)
        return torch.stack(examples, dim=0)



# Label value for positions excluded from the loss
# (-100 is the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``).
IGNORE_INDEX = -100

# Sentence shared by both Alpaca prompt templates below.
_ALPACA_RESPONSE_LINE = "Write a response that appropriately completes the request.\n\n"

# Alpaca prompt templates, filled via ``str.format_map`` with a record that has
# an ``instruction`` key (and, for ``prompt_input``, an ``input`` key).
PROMPT_DICT = dict(
    prompt_input=(
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        + _ALPACA_RESPONSE_LINE
        + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    prompt_no_input=(
        "Below is an instruction that describes a task. "
        + _ALPACA_RESPONSE_LINE
        + "### Instruction:\n{instruction}\n\n### Response:"
    ),
)

class AlpacaDataCollator:
    """Collate ``AlpacaDataset`` items into a causal-LM training batch.

    ``input_ids`` are right-padded with ``tokenizer.eos_token_id`` and
    ``labels`` with ``IGNORE_INDEX``. ``attention_mask`` is True exactly on
    each sequence's real (unpadded) positions.
    """
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """
        Args:
            instances: dicts with 1-D tensor ``"input_ids"`` and ``"labels"``.

        Returns:
            dict with padded ``input_ids`` and ``labels`` of shape
            ``(batch, max_len)`` and a boolean ``attention_mask`` of the same shape.
        """
        input_ids = [instance["input_ids"] for instance in instances]
        labels = [instance["labels"] for instance in instances]
        lengths = torch.tensor([seq.size(0) for seq in input_ids], dtype=torch.long)

        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.eos_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)

        # eos doubles as the pad value, so ``input_ids != eos`` would also hide
        # every genuine EOS token. Build the mask from the true lengths instead.
        positions = torch.arange(input_ids.size(1), dtype=torch.long)
        attention_mask = positions.unsqueeze(0) < lengths.unsqueeze(1)

        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )

    def _tensorize_batch(self, examples):
        """Stack equal-length sequences; otherwise right-pad them with -100."""
        # In order to accept both lists of lists and lists of Tensors
        if isinstance(examples[0], (list, tuple)):
            examples = [torch.tensor(e, dtype=torch.long) for e in examples]
        length_of_first = examples[0].size(0)
        are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
        if are_tensors_same_length:
            return torch.stack(examples, dim=0)
        else:
            return pad_sequence(examples, batch_first=True, padding_value=-100)


class AlpacaDataset(Dataset):
    """Instruction-tuning dataset built from an Alpaca-format HF dataset.

    Each record is rendered with ``PROMPT_DICT``. The ``prompt_input``
    template is used when its ``input`` field is non-empty, otherwise
    ``prompt_no_input``. The target is the record's ``output`` plus
    ``tokenizer.eos_token``. Prompt tokens are masked with ``IGNORE_INDEX``
    in ``labels``.
    """

    def __init__(self, data_dir: str, tokenizer, block_size=1024, split="train"):
        """
        Args:
            data_dir: path or name passed to ``load_dataset``.
            tokenizer: HF-style tokenizer; ``eos_token``, ``eos_token_id`` and
                ``model_max_length`` are used.
            block_size: stored on the instance only; truncation length is
                ``tokenizer.model_max_length``.
            split: split of the loaded dataset to build examples from.
        """
        super().__init__()
        self.data = load_dataset(data_dir)
        self.block_size = block_size
        self.split = split
        records = self.data[split]

        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
        sources = [
            prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
            for example in records
        ]
        targets = [f"{example['output']}{tokenizer.eos_token}" for example in records]

        data_dict = self.preprocess(sources, targets, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])

    def select(self, indices):
        """Return a shallow copy restricted to ``indices`` (in the given order).

        Tokenized tensors are shared with this dataset, not copied. The raw
        ``data`` attribute still refers to the full loaded dataset.
        """
        indices = list(indices)
        subset = copy.copy(self)
        subset.input_ids = [self.input_ids[i] for i in indices]
        subset.labels = [self.labels[i] for i in indices]
        return subset

    def _tokenize_fn(self, strings: Sequence[str], tokenizer) -> Dict:
        """Tokenize each string separately, truncating to ``model_max_length``.

        No padding is requested. Every string is encoded on its own, so
        padding would be a no-op, and HF tokenizers raise when asked to pad
        without a pad token. Each ``*_lens`` entry counts the ids that differ
        from ``tokenizer.eos_token_id``.
        """
        tokenized_list = [
            tokenizer(
                text,
                return_tensors="pt",
                max_length=tokenizer.model_max_length,
                truncation=True,
            )
            for text in strings
        ]
        input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
        input_ids_lens = labels_lens = [
            tokenized.input_ids.ne(tokenizer.eos_token_id).sum().item() for tokenized in tokenized_list
        ]

        return dict(
            input_ids=input_ids,
            labels=labels,
            input_ids_lens=input_ids_lens,
            labels_lens=labels_lens,
        )

    def preprocess(
        self,
        sources: Sequence[str],
        targets: Sequence[str],
        tokenizer,
    ) -> Dict:
        """Tokenize ``source + target`` pairs and mask the source prefix in labels."""
        examples = [s + t for s, t in zip(sources, targets)]
        examples_tokenized, sources_tokenized = [self._tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
        input_ids = examples_tokenized["input_ids"]
        labels = copy.deepcopy(input_ids)
        for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
            label[:source_len] = IGNORE_INDEX
        return dict(input_ids=input_ids, labels=labels)


class ArrowNCPDataset(Dataset):
    """Overlapping windows of pre-tokenized rows from a ``load_from_disk`` dataset.

    Item ``idx`` concatenates rows
    ``[idx * block_size_nums, (idx + w_size) * block_size_nums)`` into one
    flat int64 tensor. Consecutive items therefore start ``block_size_nums``
    rows apart. ``labels`` is the same tensor as ``input_ids``.
    """

    def __init__(self, data_dir, block_size=16384, split="train"):
        """
        Raises:
            ValueError: if ``block_size < 8192``; ``w_size`` would then be 0
                and the length computation would divide by zero.
        """
        if block_size < 8192:
            raise ValueError(f"ArrowNCPDataset needs block_size >= 8192, got {block_size}")
        self.data_dir = data_dir
        self.data = load_from_disk(os.path.join(data_dir, split))
        # NOTE(review): the 8192 / 2048 divisors presumably encode a fixed
        # row size of the preprocessed data — confirm.
        self.w_size = block_size // 8192
        self.block_size_nums = block_size // 2048
        self.block_size = block_size
        self.split = split
        self.data_len = self._num_items(len(self.data))

    def _num_items(self, n_rows):
        """Number of items exposed for a table of ``n_rows`` rows."""
        return n_rows // self.w_size // self.block_size_nums

    @staticmethod
    def _flatten(column):
        """Concatenate a column of id lists into one 1-D int64 tensor."""
        return torch.from_numpy(np.array(list(itertools.chain(*column))).reshape(-1).astype(np.int64))

    def __len__(self):
        return self.data_len

    def __getitem__(self, idx):
        start = idx * self.block_size_nums
        stop = (idx + self.w_size) * self.block_size_nums
        rows = self.data[start:stop]  # one read serves both columns
        x = self._flatten(rows['input_ids'])
        attention_mask = self._flatten(rows['attention_mask'])
        return {"input_ids": x, "labels": x, "attention_mask": attention_mask}

    def select(self, indices):
        """Return a copy backed by ``self.data.select(indices)``.

        ``data_len`` uses the same formula as ``__init__``. The copy shares
        every other attribute, so no path is reconstructed and nothing is
        reloaded from disk.
        """
        new_dataset = copy.copy(self)
        new_dataset.data = self.data.select(indices)
        new_dataset.data_len = new_dataset._num_items(len(new_dataset.data))
        return new_dataset

class ArrowDataset(Dataset):
    """Non-overlapping blocks of pre-tokenized rows from a ``load_from_disk`` dataset.

    Item ``idx`` takes rows ``[idx * block_size_nums, (idx + 1) * block_size_nums)``
    (with ``block_size_nums = block_size // 2048``) and flattens them into one
    1-D int64 tensor. ``labels`` is the same tensor as ``input_ids``.
    """

    def __init__(self, data_dir, block_size=1024, split="train"):
        """
        Raises:
            ValueError: if ``block_size < 2048``; ``block_size_nums`` would
                then be 0 and the length computation would divide by zero.
                NOTE(review): the default ``block_size=1024`` is below this limit.
        """
        if block_size < 2048:
            raise ValueError(f"ArrowDataset needs block_size >= 2048, got {block_size}")
        self.data = load_from_disk(os.path.join(data_dir, split))
        self.block_size = block_size
        self.block_size_nums = block_size // 2048
        self.split = split
        self.data_len = len(self.data) // self.block_size_nums

    def __len__(self):
        return self.data_len

    def __getitem__(self, idx):
        start = idx * self.block_size_nums
        rows = self.data[start:start + self.block_size_nums]  # one read serves both columns
        x = torch.from_numpy(np.array(rows['input_ids']).reshape(-1).astype(np.int64))
        attention_mask = torch.from_numpy(np.array(rows['attention_mask']).reshape(-1).astype(np.int64))
        return {"input_ids": x, "labels": x, "attention_mask": attention_mask}

    def select(self, indices):
        """Return a copy backed by ``self.data.select(indices)``.

        ``data_len`` is recomputed with the same formula as ``__init__``.
        """
        new_dataset = copy.copy(self)
        new_dataset.data = self.data.select(indices)
        new_dataset.data_len = len(new_dataset.data) // new_dataset.block_size_nums
        return new_dataset

class PreprocessedDataset(Dataset):
    """Fixed-length token windows read from a flat ``uint16`` token file.

    ``{data_dir}/{split}.bin`` is memory-mapped and exposes
    ``len(tokens) // block_size`` items.

    - Non-train splits: item ``idx`` is the block starting at ``idx * block_size``.
    - ``"train"`` split: the start is shifted by ``random.randint(0, block_size)``.
      The last item instead uses ``random.randint(0, remain)``, so every window
      still lies inside the file.

    Returns int64 ``input_ids`` (also used as ``labels``) and an all-ones
    ``attention_mask``. ``task`` is accepted but unused.
    """

    def __init__(self, data_dir, block_size=1024, split="train", task='contextlm'):
        bin_path = os.path.join(data_dir, f"{split}.bin")
        self.data = np.memmap(bin_path, dtype=np.uint16, mode="r")
        self.block_size = block_size
        self.split = split
        # Trailing partial block is not its own item; ``remain`` bounds the last shift.
        self.data_len, self.remain = divmod(len(self.data), self.block_size)

    def __len__(self):
        return self.data_len

    def __getitem__(self, idx):
        start = idx * self.block_size
        if self.split == "train":
            on_last_block = idx >= self.data_len - 1
            max_shift = self.remain if on_last_block else self.block_size
            start += random.randint(0, max_shift)
        window = self.data[start : start + self.block_size]
        x = torch.from_numpy(window.astype(np.int64))
        return {"input_ids": x, "labels": x, "attention_mask": torch.ones_like(x)}

    def select(self, indices):
        """Return a view whose position ``k`` yields ``self[indices[k]]`` (no copying)."""
        class SubsetDataset(Dataset):
            def __init__(self, parent, positions):
                self.dataset = parent
                self.indices = positions

            def __len__(self):
                return len(self.indices)

            def __getitem__(self, idx):
                return self.dataset[self.indices[idx]]

        return SubsetDataset(self, indices)

def get_train_dataloader(cfg):
    """Build the shuffled training DataLoader selected by ``cfg.dataset``.

    Supported datasets: ``openwebtext`` / ``fineweb`` (memmap token files),
    ``pile`` (Arrow; NCP windows when ``cfg.mode == 'ncp'``) and ``alpaca``
    (instruction data, collated with ``AlpacaDataCollator``).

    Side effect: sets ``cfg.n_epochs`` to
    ``train_steps * update_batch_size * grad_acc_steps // len(dataset) + 1``.

    Raises:
        NotImplementedError: if ``cfg.dataset`` is not supported.
        ValueError: if the selected training dataset is empty.
    """
    kwargs = {}
    if cfg.dataset == "openwebtext":
        train_dataset = PreprocessedDataset(
            cfg.data_dir, block_size=cfg.block_size, split="train", task=cfg.mode,
        )
    elif cfg.dataset == "pile":
        dataset_cls = ArrowNCPDataset if cfg.mode == 'ncp' else ArrowDataset
        train_dataset = dataset_cls(cfg.data_dir, block_size=cfg.block_size, split="train")
    elif cfg.dataset == 'fineweb':
        train_dataset = PreprocessedDataset(cfg.data_dir, block_size=cfg.block_size, split="train")
    elif cfg.dataset == 'alpaca':
        # Imported locally: only this branch needs ``transformers``, and the
        # module never imports AutoTokenizer at top level.
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(cfg.feature_extractor_model)
        train_dataset = AlpacaDataset(cfg.data_dir, tokenizer, split="train")
        # Alpaca items have different lengths; the default collate_fn cannot stack them.
        kwargs["collate_fn"] = AlpacaDataCollator(tokenizer)
    else:
        raise NotImplementedError(f"dataset [{cfg.dataset}] not supported for training")

    n_train = len(train_dataset)
    if n_train == 0:
        raise ValueError(f"training dataset [{cfg.dataset}] is empty")

    batch_size = cfg.update_batch_size // cfg.world_size
    cfg.n_epochs = (
        cfg.train_steps
        * cfg.update_batch_size
        * cfg.grad_acc_steps
        // n_train
        + 1
    )
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=cfg.num_workers,
        **kwargs,
    )
    return train_dataloader

def get_val_dataloaders(cfg):
    """Build one non-shuffled validation DataLoader per supported name in ``cfg.val_datasets``.

    Returns ``None`` when ``cfg.val_datasets`` is ``None``. Otherwise returns a
    dict mapping each supported name (``openwebtext``, ``fineweb``, ``pile``)
    to its DataLoader, in ``cfg.val_datasets`` order. Unsupported names are
    skipped silently.
    """
    if cfg.val_datasets is None:
        return None

    def build_dataset(name):
        # Returns None for names that have no validation dataset.
        if name == "openwebtext" or name == "fineweb":
            return PreprocessedDataset(cfg.data_dir, block_size=cfg.block_size, split="val")
        if name == 'pile':
            arrow_cls = ArrowNCPDataset if cfg.mode == 'ncp' else ArrowDataset
            return arrow_cls(cfg.data_dir, block_size=cfg.block_size, split="validation")
        return None

    loaders = {}
    for name in cfg.val_datasets:
        dataset = build_dataset(name)
        if dataset is None:
            continue
        loaders[name] = DataLoader(
            dataset,
            shuffle=False,
            batch_size=cfg.batch_size_eval // cfg.world_size,
            pin_memory=True,
        )
    return loaders
