# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import json
from tqdm import tqdm

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from datasets import load_dataset


# Label value for target positions that must not count as prediction targets.
# The padded target positions built further down in this file use the same
# value (-100). NOTE(review): presumably matches the loss function's ignore
# index -- confirm against the training/evaluation code.
IGNORE_INDEX = -100
# Special-token strings. NOTE(review): none of these are referenced in the
# visible part of this file; presumably fallbacks for tokenizers that lack the
# corresponding special tokens -- confirm usage before changing or removing.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"

def get_generated_dataset(path, tokenizer, max_len, batch_size=8, return_loader=False):
    """Load a JSON Lines file of generated text into a ``GeneratedDataset``.

    Args:
        path: Path to a UTF-8 encoded JSON Lines file (one JSON document per
            line). ``GeneratedDataset`` reads the ``"text"`` field of every
            record.
        tokenizer: Tokenizer forwarded to ``GeneratedDataset``.
        max_len: Chunk length in tokens; forwarded as ``block_size``.
        batch_size: Batch size of the returned ``DataLoader``. Only used when
            ``return_loader`` is true.
        return_loader: When true, wrap the dataset in a shuffling
            ``DataLoader`` and return that instead of the dataset.

    Returns:
        The ``GeneratedDataset``, or a ``DataLoader`` over it when
        ``return_loader`` is true.

    Raises:
        OSError: If ``path`` cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        # Blank or whitespace-only lines (e.g. a stray trailing newline) carry
        # no record; skip them instead of letting json.loads fail on them.
        data = [json.loads(line) for line in f if line.strip()]

    dataset = GeneratedDataset(data, tokenizer, block_size=max_len)

    if not return_loader:
        return dataset

    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


def get_wiki_dataset(tokenizer, max_len, batch_size=4, split='validation', return_loader=False):
    """Build a ``WikiDataset`` from the raw WikiText-2 corpus.

    Args:
        tokenizer: Tokenizer handed to ``WikiDataset``.
        max_len: Window length in tokens handed to ``WikiDataset``.
        batch_size: Batch size of the returned ``DataLoader``; ignored unless
            ``return_loader`` is true.
        split: Name of the ``wikitext-2-raw-v1`` split to load.
        return_loader: When true, return a non-shuffling ``DataLoader`` over
            the dataset instead of the dataset itself.

    Returns:
        The ``WikiDataset``, or a ``DataLoader`` wrapping it.
    """
    raw_split = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
    wiki = WikiDataset(raw_split, tokenizer, max_len)

    if return_loader:
        # Order is kept (no shuffling) so batches follow the corpus order.
        return DataLoader(wiki, batch_size=batch_size, shuffle=False)
    return wiki

class GeneratedDataset(Dataset):
    """Dataset of fixed-size token chunks built from text records.

    The ``"text"`` field of every record is tokenized (truncated to
    ``block_size`` tokens), the token sequences of all records are
    concatenated, and the result is cut into consecutive chunks of
    ``block_size`` tokens. Labels are the same token ids as the inputs.
    """

    def __init__(self, dataset, tokenizer, block_size=2048):
        """Tokenize and chunk ``dataset``.

        Args:
            dataset: Iterable of mappings, each with a ``"text"`` entry.
            tokenizer: Callable accepting ``(text, truncation=...,
                max_length=..., return_tensors="pt")`` and returning a mapping
                from field name to tensor.
            block_size: Number of tokens per chunk; also the per-record
                truncation length.
        """
        self.tokenizer = tokenizer
        self.block_size = block_size

        tokenized_data = [
            self.tokenize_function(d)
            for d in tqdm(dataset, desc="Tokenizing", unit="item")
        ]

        grouped_data = self.group_texts(tokenized_data)

        self.input_ids = grouped_data["input_ids"]
        # Only present when the tokenizer emits an attention mask.
        self.attention_mask = grouped_data.get("attention_mask", None)
        self.labels = grouped_data["labels"]

    def __len__(self):
        """Return the number of chunks."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return chunk ``idx`` as a dict of tensors.

        The dict always holds ``"input_ids"`` and ``"labels"``;
        ``"attention_mask"`` is added when the tokenizer produced one.
        """
        item = {
            "input_ids": torch.tensor(self.input_ids[idx]),
            "labels": torch.tensor(self.labels[idx])
        }
        # Explicit None check: the mask is either absent or a list of chunks.
        if self.attention_mask is not None:
            item["attention_mask"] = torch.tensor(self.attention_mask[idx])
        return item

    def tokenize_function(self, example):
        """Tokenize ``example["text"]``, truncated to ``block_size`` tokens.

        Returns the tokenizer output with PyTorch tensors; ``group_texts``
        squeezes dimension 0 of each tensor before concatenating.
        """
        return self.tokenizer(
            example["text"],
            truncation=True,
            max_length=self.block_size,
            return_tensors="pt"
        )

    def group_texts(self, examples):
        """Concatenate tokenized examples and split them into chunks.

        Args:
            examples: Sequence of tokenizer outputs (mappings from field name
                to tensor). The fields of the first example define the fields
                of the result.

        Returns:
            A dict mapping every field, plus ``"labels"``, to a list of chunks
            (lists of ints) of ``block_size`` tokens each. When the total
            token count is at least ``block_size`` the trailing remainder is
            dropped; when it is smaller, a single shorter chunk is returned.
            An empty ``examples`` yields empty ``"input_ids"`` and
            ``"labels"`` lists.
        """
        if not examples:
            # Nothing to concatenate: report an empty dataset rather than
            # failing on examples[0] below.
            return {"input_ids": [], "labels": []}

        concatenated = {k: [] for k in examples[0].keys()}

        for ex in examples:
            for k, v in ex.items():
                concatenated[k].append(v.squeeze(0))

        # One flat list of ints per field.
        result = {
            k: torch.cat(parts).tolist()
            for k, parts in concatenated.items()
        }

        total_length = len(result["input_ids"])
        if total_length >= self.block_size:
            # Drop the trailing remainder so every chunk is full.
            total_length = (total_length // self.block_size) * self.block_size

        chunked = {
            k: [
                v[i:i + self.block_size]
                for i in range(0, total_length, self.block_size)
            ]
            for k, v in result.items()
        }

        chunked["labels"] = chunked["input_ids"].copy()

        return chunked



class WikiDataset(Dataset):
    """Fixed-length, non-overlapping token windows over a text corpus.

    All entries of ``data["text"]`` are joined with blank lines and tokenized
    in a single call. The token sequence is then cut into consecutive windows
    of ``max_len`` tokens; the last window is right-padded to ``max_len``.
    """

    def __init__(self, data, tokenizer, max_len=2048):
        """Tokenize the corpus and precompute all (input, target) windows.

        Args:
            data: Object whose ``["text"]`` entry is an iterable of strings.
            tokenizer: Callable accepting ``(text, return_tensors="pt")`` and
                returning an object with an ``input_ids`` tensor that has the
                token positions on dimension 1. Must expose ``pad_token_id``
                and ``eos_token_id``.
            max_len: Window length in tokens.
        """
        self.tokenizer = tokenizer
        self.max_len = max_len

        # self.data first holds the tokenizer output (read by
        # create_input_target_pairs) and afterwards the list of windows.
        self.data = self.tokenizer("\n\n".join(data["text"]), return_tensors="pt")
        self.data = self.create_input_target_pairs()

    def __len__(self):
        """Return the number of windows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_ids, target_ids)`` of window ``idx``, dim 0 squeezed."""
        input_ids, target_ids = self.data[idx]
        return input_ids.squeeze(0), target_ids.squeeze(0)

    def create_input_target_pairs(self):
        """Split the tokenized corpus in ``self.data`` into padded windows.

        Returns:
            A list of ``(input_ids, target_ids)`` tensor pairs, one per
            window. Targets equal the inputs, except that padded positions
            are set to ``-100``.
        """
        token_ids = self.data.input_ids
        seq_len = token_ids.size(1)
        input_target_pairs = []

        # Windows advance by exactly max_len, so they never overlap and every
        # real token of a window is a prediction target.
        for begin_loc in range(0, seq_len, self.max_len):
            end_loc = min(begin_loc + self.max_len, seq_len)

            input_ids = token_ids[:, begin_loc:end_loc]
            target_ids = input_ids.clone()

            pad_len = self.max_len - input_ids.size(1)
            if pad_len > 0:
                # Explicit None check: a pad token id of 0 is valid and must
                # not be replaced by the EOS id.
                pad_token_id = self.tokenizer.pad_token_id
                if pad_token_id is None:
                    pad_token_id = self.tokenizer.eos_token_id
                input_ids = torch.nn.functional.pad(input_ids, (0, pad_len), value=pad_token_id)
                # Padded positions get the label -100.
                target_ids = torch.nn.functional.pad(target_ids, (0, pad_len), value=-100)

            input_target_pairs.append((input_ids, target_ids))

        return input_target_pairs
