from torch.utils.data import Dataset, DataLoader
import math
import torch
import os
from datasets import load_dataset
import tqdm
import math
from torch.utils.data.distributed import DistributedSampler
import gc
import multiprocessing as mp

class ChunkedIterator:
    """Group the ``'text'`` field of consecutive records into lists.

    Each ``next()`` call returns a list of up to ``chunk_size`` values taken
    from ``record['text']`` of successive records of the wrapped iterator.

    Args:
        iterator: An iterator (``next()`` is called on it directly) yielding
            records that support ``record['text']``.
        chunk_size: Number of records collected into one chunk.
        cutoff: Optional record budget. Once ``chunks_emitted * chunk_size``
            is greater than ``cutoff``, one more record is read, returned as a
            final single-element chunk, and iteration ends. ``None`` (the
            default) disables the limit.
    """

    def __init__(self, iterator, chunk_size, cutoff=None):
        self.iterator = iterator
        self.chunk_size = chunk_size
        # Set once the source is exhausted or the cutoff has been reached.
        self.eof = False
        # Number of chunks returned so far.
        self.ichunk = 0
        self.cutoff = cutoff

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next chunk; raise ``StopIteration`` when nothing is left."""
        if self.eof:
            raise StopIteration

        # ``ichunk`` does not change while a chunk is being filled, so the
        # cutoff test is evaluated once per chunk. A ``None`` cutoff means
        # "no limit" (comparing an int against ``None`` would raise TypeError).
        past_cutoff = (
            self.cutoff is not None
            and self.ichunk * self.chunk_size > self.cutoff
        )

        chunk = []
        while True:
            try:
                record = next(self.iterator)
            except StopIteration:
                self.eof = True
                break
            chunk.append(record['text'])
            if past_cutoff:
                # The record just read is still returned; nothing more is read.
                self.eof = True
                break
            if len(chunk) >= self.chunk_size:
                break

        if not chunk:
            # The source ended exactly on a chunk boundary (or was empty):
            # finish cleanly instead of handing out an empty chunk.
            raise StopIteration

        self.ichunk += 1
        return chunk

class Wikitext2Dataset(Dataset):
    """Token-window dataset over a tokenized text corpus.

    The corpus is ``wikitext-2-raw-v1`` by default, or OpenWebText when the
    environment variable ``FORCE_OPENWEBTEXT`` is ``'1'``. The tokenized
    corpus is stored as one tensor in ``./cache/wikitext/`` and reloaded from
    there on later constructions.

    NOTE(review): the cache file name depends only on the dataset name and
    the subset, not on the tokenizer, so a cache written with one tokenizer
    is reused as-is when a different tokenizer is passed.
    """

    def __init__(self, subset, tokenizer, stride=2048, max_length=None, strided_indexing=None):
        """Load the cached token tensor, or tokenize the corpus and cache it.

        Args:
            subset: Split name; ``'valid'`` is treated as ``'validation'``.
            tokenizer: Callable invoked as
                ``tokenizer(text, return_tensors='pt', add_special_tokens=False)``
                whose result exposes ``.input_ids``.
            stride: Distance in tokens between window starts under strided
                indexing; also used to size the dataset (see ``__len__``).
            max_length: Window length in tokens. Must be a positive number by
                the time items are fetched.
            strided_indexing: When truthy, item ``idx`` starts at
                ``idx * stride``; otherwise it starts at ``idx``. If left as
                ``None`` it becomes ``True`` for the validation/test subsets.
        """
        super().__init__()

        self.tokenizer = tokenizer
        if subset == 'valid':
            subset = 'validation'
        if subset in ['validation', 'test'] and strided_indexing is None:
            strided_indexing = True
        self.strided_indexing = strided_indexing

        os.makedirs('./cache/wikitext', exist_ok=True)
        dataset = 'wikitext2'
        if os.environ.get('FORCE_OPENWEBTEXT', '0') == '1':
            print('FORCELY USE OPENWEBTEXT!')
            dataset = 'openwebtext'
        cache_path = f'./cache/wikitext/{dataset}-{subset}.pth'
        if os.path.exists(cache_path):
            # NOTE(review): torch.load unpickles the file; only load cache
            # files this process (or a trusted one) has written.
            self.encodings = torch.load(cache_path)
            print('cache size', self.encodings.shape)
        else:
            self.encodings = self._build_encodings(dataset, subset)
            torch.save(self.encodings, cache_path)

        # Tokens live along dim 1 (chunks are concatenated with dim=1).
        self.seq_len = self.encodings.size(1)
        print(f'{self.seq_len} tokens loaded')
        self.stride = stride
        self.max_length = max_length
        # For the train subset every fetched window must have the same shape.
        self.check_last_shape = subset == 'train'
        self.last_shape = None

    def _build_encodings(self, dataset, subset):
        """Load the raw corpus and tokenize it chunk by chunk in worker processes."""
        cutoff_dataset = 5000000  # 5M documents
        if dataset == 'openwebtext':
            chunk_size = 50
            if subset == 'train':
                cutoff_dataset = 500000  # 500k documents
                data = load_dataset("Skylion007/openwebtext", split='train[:99%]')
            else:
                cutoff_dataset = 2000  # 2k documents
                data = load_dataset("Skylion007/openwebtext", split='train[99%:]')
            print('OPENWEBTEXT loaded')
        else:
            chunk_size = 50*1000
            data = load_dataset("wikitext", "wikitext-2-raw-v1", split=subset)
        # Set before the worker processes are created; presumably to keep the
        # tokenizer's own parallelism from clashing with multiprocessing.
        os.environ['TOKENIZERS_PARALLELISM'] = 'false'

        encodings = []
        encodings_size = 0
        chunked_iterator = ChunkedIterator(iter(data), chunk_size, cutoff=cutoff_dataset)
        # Leave one core free, but never ask for zero workers: Pool(0) raises
        # ValueError on single-CPU machines.
        num_workers = max(1, mp.cpu_count() - 1)
        with mp.Pool(num_workers) as pool, \
                tqdm.tqdm(pool.imap(self.get_encodings, chunked_iterator, chunksize=8), dynamic_ncols=True, total=math.ceil(cutoff_dataset/chunk_size)) as pbar:
            for chunk_encodings in pbar:
                encodings_size += chunk_encodings.shape[1]
                pbar.set_description(f'tokens: {encodings_size}')
                encodings.append(torch.tensor(chunk_encodings))

        return torch.cat(encodings, dim=1)

    def get_encodings(self, chunk):
        """Tokenize one chunk of texts, joined by blank lines, into a NumPy array.

        Returns a NumPy array (rather than a tensor) of the tokenizer's
        ``input_ids``; the caller converts it back with ``torch.tensor``.
        """
        flatten_text = "\n\n".join(chunk)
        chunk_encodings = self.tokenizer(flatten_text, return_tensors='pt', add_special_tokens=False).input_ids
        return chunk_encodings.numpy()

    def __len__(self):
        """Return the number of windows exposed by the dataset."""
        if self.strided_indexing:
            # drop last by default
            return max(math.floor(self.seq_len / self.stride), 1)
        else:
            return self.seq_len - self.stride

    def __getitem__(self, idx):
        """Return the token window for ``idx`` with non-target labels masked.

        Returns:
            A dict with ``'input_ids'`` (the window), ``'labels'`` (a copy of
            the window whose leading positions are set to ``-100`` so that
            only the last ``trg_len`` positions remain), and ``'trg_len'``
            (0-dim tensor).

        Raises:
            ValueError: If ``max_length`` is ``None`` or not positive.
        """
        max_length = self.max_length
        if max_length is None or max_length <= 0:
            raise ValueError(f'max_length must be a positive number, got {max_length!r}')

        if not self.strided_indexing:
            begin_loc = idx
        else:
            begin_loc = idx * self.stride

        end_loc = min(begin_loc + max_length, self.seq_len)
        trg_len = end_loc - min(begin_loc - self.stride + max_length, self.seq_len)

        input_ids = self.encodings[:, begin_loc:end_loc]
        target_ids = input_ids.clone()
        # Mask everything except the last ``trg_len`` positions. Computing the
        # prefix length explicitly (instead of slicing with ``:-trg_len``)
        # keeps ``trg_len == 0`` from leaving every label unmasked, because
        # ``:-0`` is the empty slice ``:0``. The clamp keeps the result
        # "nothing masked" when ``trg_len`` exceeds the window length.
        num_masked = max(input_ids.shape[1] - trg_len, 0)
        target_ids[:, :num_masked] = -100

        if self.check_last_shape:
            if self.last_shape is not None:
                assert self.last_shape == input_ids.shape
            self.last_shape = input_ids.shape

        return {
            'input_ids': input_ids[0],
            'labels': target_ids[0],
            'trg_len': torch.tensor(trg_len),
        }

def get_dataloader(subset, tokenizer, batch_size=1, max_length=None, local_rank=0, world_size=1):
    """Build a ``DataLoader`` over ``Wikitext2Dataset`` for the given subset.

    The dataset is created with ``stride`` equal to ``max_length``. Shuffling
    is enabled only for the ``'train'`` subset. When ``world_size`` is greater
    than 1 the data is split with a ``DistributedSampler`` using
    ``local_rank`` as the rank; otherwise a plain loader is returned.
    ``max_length`` must not be ``None``.
    """
    assert max_length is not None
    ds = Wikitext2Dataset(subset, tokenizer, stride=max_length, max_length=max_length)
    shuffle_train = (subset == 'train')

    # Options common to both the distributed and single-process loaders.
    loader_kwargs = {'batch_size': batch_size, 'num_workers': 0}

    if world_size > 1:
        # Shuffling is delegated to the sampler in the distributed case.
        loader_kwargs['sampler'] = DistributedSampler(
            dataset=ds,
            num_replicas=world_size,
            rank=local_rank,
            shuffle=shuffle_train,
        )
    else:
        loader_kwargs['shuffle'] = shuffle_train

    return DataLoader(ds, **loader_kwargs)

if __name__ == '__main__':
    # Smoke run: build the validation loader with the OPT-125m tokenizer and
    # iterate it once end to end. Swap 'valid' for 'train' to exercise the
    # training split instead.
    import transformers

    tokenizer_cls = transformers.AutoTokenizer
    opt_tokenizer = tokenizer_cls.from_pretrained('facebook/opt-125m')
    valid_loader = get_dataloader('valid', opt_tokenizer, batch_size=1, max_length=768)

    # Batches are only consumed, not inspected.
    for _ in tqdm.tqdm(valid_loader):
        continue