import numpy as np
import torch
from torch.utils.data import Dataset
from .binidx import MMapIndexedDataset
import json

import torch.distributed as dist

class MMapIndexedDatasetWithContext(Dataset):
    """Serve pseudo-random fixed-length windows from a memory-mapped dataset.

    Every item is a window of ``args.ctx_len + 1`` consecutive ids read from
    ``args.data_file`` and split into next-token-prediction pairs: the inputs
    are the first ``ctx_len`` ids and the labels are the same window shifted
    by one position.
    """

    def __init__(self, args):
        """Open the memory-mapped dataset described by ``args``.

        Args:
            args: Configuration object. This class reads ``vocab_size`` and
                ``data_file`` here, and ``ctx_len``, ``epoch_steps`` and
                ``micro_bsz`` later on.
        """
        self.args = args
        self.vocab_size = args.vocab_size
        self.data = MMapIndexedDataset(args.data_file)
        # Element count of the raw buffer: its length divided by the size of
        # one stored element (presumably one token id -- confirm in binidx).
        self.data_size = len(self.data._bin_buffer) // self.data._index._dtype_size

    def __len__(self):
        """Return the nominal epoch length, ``epoch_steps * micro_bsz``."""
        return self.args.epoch_steps * self.args.micro_bsz

    def __getitem__(self, idx):
        """Return one ``(input_ids, labels)`` window chosen from ``idx``.

        The window offset is drawn from a generator seeded with
        ``idx + rank * 1000``, so a given ``(idx, rank)`` pair always maps to
        the same window.

        Args:
            idx: Sample index; only used to derive the random seed.

        Returns:
            dict with ``"input_ids"`` and ``"labels"``, each a 1-D
            ``torch.long`` tensor of length ``ctx_len``.

        Raises:
            ValueError: If the data is not longer than ``ctx_len + 1``
                elements, so no window offset can be drawn.
        """
        req_len = self.args.ctx_len + 1  # one extra id for the shifted labels
        if self.data_size <= req_len:
            raise ValueError(
                f"data has {self.data_size} elements but a window of "
                f"{req_len} is required; reduce ctx_len or use more data"
            )

        distributed = torch.distributed
        if distributed.is_available() and distributed.is_initialized():
            rank = distributed.get_rank()
        else:
            rank = 0

        # A private generator gives exactly the draw that
        # ``np.random.seed(seed); np.random.randint(...)`` would, without
        # resetting NumPy's process-wide random state on every fetch.
        # NOTE(review): ``idx + rank * 1000`` repeats across ranks once
        # idx >= 1000 (rank 0 / idx 1000 == rank 1 / idx 0). Left unchanged
        # so existing runs keep sampling the same windows.
        rng = np.random.RandomState(idx + rank * 1000)
        start = rng.randint(0, self.data_size - req_len)

        # Always addresses entry 0 and offsets into it; presumably ``get``
        # reads contiguously across entry boundaries -- confirm in binidx.
        window = self.data.get(idx=0, offset=start, length=req_len).astype(np.int64)

        return {
            "input_ids": torch.tensor(window[:-1], dtype=torch.long),
            "labels": torch.tensor(window[1:], dtype=torch.long),
        }


class JsonlDataset(Dataset):
    """Dataset over a JSON Lines file whose records carry a ``"text"`` field.

    All records are parsed into memory at construction time; tokenization
    happens lazily in ``__getitem__``.
    """

    def __init__(self, file_path, tokenizer, ctx_len):
        """Load every record of ``file_path`` into memory.

        Args:
            file_path: Path to a UTF-8 encoded JSON Lines file.
            tokenizer: Object exposing
                ``encode(text, truncation=..., max_length=...)`` that returns
                a sequence of token ids.
            ctx_len: Value passed to the tokenizer as ``max_length``.
        """
        self.file_path = file_path
        self.tokenizer = tokenizer
        self.ctx_len = ctx_len
        self.data = self._load_jsonl(file_path)

    def _load_jsonl(self, file_path):
        """Parse ``file_path`` and return its records as a list.

        Blank and whitespace-only lines (for example a trailing newline at
        the end of the file) are skipped rather than handed to the JSON
        parser, which would reject them.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            stripped_lines = (line.strip() for line in f)
            return [json.loads(line) for line in stripped_lines if line]

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize record ``idx`` and return next-token-prediction tensors.

        Args:
            idx: Position of the record in the loaded list.

        Returns:
            dict with ``"input_ids"`` (all tokens but the last) and
            ``"labels"`` (all tokens but the first), both 1-D ``torch.long``
            tensors.

        Raises:
            KeyError: If the record has no ``"text"`` field.
        """
        text = self.data[idx]['text']
        # NOTE(review): truncation is requested at ctx_len tokens, so after
        # the one-token shift each tensor holds at most ctx_len - 1 ids.
        tokens = self.tokenizer.encode(text, truncation=True, max_length=self.ctx_len)

        return {
            "input_ids": torch.tensor(tokens[:-1], dtype=torch.long),
            "labels": torch.tensor(tokens[1:], dtype=torch.long),
        }
