from datasets import load_dataset
import numpy as np
import os
import tiktoken
import torch
from tqdm.auto import tqdm


class Data:
    """Token-level TinyStories data with random-window batch sampling.

    On construction the 'roneneldan/TinyStories' dataset is tokenized with
    the GPT-2 encoding and written to ``<root>/<name>/trn.bin`` and
    ``<root>/<name>/tst.bin`` as flat ``uint16`` arrays (this step is
    skipped when the test file already exists). Batches are then drawn as
    random contiguous windows of ``block_size`` tokens from those files.

    Args:
        name: Dataset identifier; only ``'ts'`` is supported.
        root: Directory under which token files and the download cache
            are stored.
        batch_trn: Number of sequences in one train batch.
        batch_tst: Number of sequences in one test batch.
        size_trn: Number of batches yielded per pass over ``loader_trn``.
        size_tst: Number of batches yielded per pass over ``loader_tst``.
        block_size: Length, in tokens, of every sampled sequence.
        d: Optional number of target shifts (see ``get_batch``).
    """

    def __init__(self, name='ts', root='_data', batch_trn=6, batch_tst=6,
                 size_trn=100, size_tst=100, block_size=1024, d=None):
        assert name == 'ts', "only the 'ts' dataset is supported"

        self.name = name
        self.root = root
        self.batch_trn = batch_trn
        self.batch_tst = batch_tst
        self.size_trn = size_trn
        self.size_tst = size_tst
        self.block_size = block_size
        self.d = d

        self.enc = tiktoken.get_encoding('gpt2')

        self._load()

        self.loader_trn = DataLoader(self.get_batch_trn, size_trn)
        self.loader_tst = DataLoader(self.get_batch_tst, size_tst)

    def decode(self, output):
        """Convert a tensor of token ids back into text.

        The tensor is flattened (not squeezed) so that a single-token
        tensor, e.g. of shape ``(1, 1)``, still produces a list of ids;
        ``squeeze().tolist()`` would collapse it to a bare int.
        """
        return self.enc.decode(output.flatten().tolist())

    def encode(self, sentence):
        """Tokenize ``sentence`` into a tensor of ids of shape ``(1, n)``."""
        return torch.tensor(self.enc.encode_ordinary(sentence)).unsqueeze(dim=0)

    def get_batch(self, split='trn'):
        """Get a random batch of train / test data.

        Args:
            split: ``'trn'`` for the train file or ``'tst'`` for the test
                file.

        Returns:
            ``(x, y)`` when ``self.d`` is None: two int64 tensors of shape
            ``(batch, block_size)``, with ``y`` being ``x`` shifted forward
            by one token. If ``self.d`` (number of shifts) is provided,
            then the function returns ``(x, y_all)``, where ``y_all`` is a
            list of targets for all shifts (1, 2, ..., d).

        Raises:
            ValueError: If ``split`` is neither ``'trn'`` nor ``'tst'``.
        """
        if split not in ('trn', 'tst'):
            # Previously any other value silently read the train file while
            # using the test batch size.
            raise ValueError(f"split must be 'trn' or 'tst', got {split!r}")

        fpath = self._get_fpath(tst=(split == 'tst'))
        # Read-only memory map: only the sampled windows are copied below.
        data = np.memmap(fpath, dtype=np.uint16, mode='r')

        batch_size = self.batch_trn if split == 'trn' else self.batch_tst
        block_size = self.block_size

        # The furthest target is shifted by `shifts` tokens, so the start
        # index must leave that much room at the end of the file.
        shifts = 1 if self.d is None else self.d
        ix = torch.randint(len(data) - block_size - shifts + 1, (batch_size,))
        starts = ix.tolist()

        def windows(offset):
            # Stack one `block_size`-long window per start, moved by `offset`.
            return torch.stack([
                torch.from_numpy(
                    data[i+offset:i+offset+block_size].astype(np.int64))
                for i in starts])

        x = windows(0)
        if self.d is None:
            return x, windows(1)
        return x, [windows(1 + k) for k in range(self.d)]

    def get_batch_trn(self):
        """Get a random batch from the train file."""
        return self.get_batch(split='trn')

    def get_batch_tst(self):
        """Get a random batch from the test file."""
        return self.get_batch(split='tst')

    def _get_fpath(self, tst=True):
        """Return the path of the test (default) or train token file."""
        fname = ('tst' if tst else 'trn') + '.bin'
        fpath = os.path.join(self.root, self.name, fname)
        return fpath

    def _load(self, batches=1024):
        """Download, tokenize and store the dataset unless already done.

        Args:
            batches: Number of contiguous shards each split is cut into
                while writing the token file.
        """
        # Only the presence of the test file is checked here.
        if os.path.exists(self._get_fpath()):
            return
        os.makedirs(os.path.join(self.root, self.name), exist_ok=True)

        ds = load_dataset('roneneldan/TinyStories',
            cache_dir=os.path.join(self.root, self.name, 'cache'))

        def process(example):
            # Tokenize one story and remember its length in tokens.
            ids = self.enc.encode_ordinary(example['text'])
            out = {'ids': ids, 'len': len(ids)}
            return out

        tokenized = ds.map(process,
            remove_columns=['text'], desc='tokenizing the splits', num_proc=8)

        for split, dset in tokenized.items():
            # Every split other than 'train' is written to the test file.
            fpath = self._get_fpath(tst=(split != 'train'))
            dtype = np.uint16 # (enc.max_token_value == 50256 is < 2**16)
            # Total number of tokens in the split = size of the flat array.
            sz = np.sum(dset['len'], dtype=np.uint64)
            arr = np.memmap(fpath, dtype=dtype, mode='w+', shape=(sz,))

            idx = 0
            for batch_idx in tqdm(range(batches), desc=f'writing {fpath}'):
                # Batch together samples for faster write:
                batch = dset.shard(num_shards=batches, index=batch_idx,
                    contiguous=True).with_format('numpy')
                arr_batch = np.concatenate(batch['ids'])
                arr[idx:idx+len(arr_batch)] = arr_batch
                idx += len(arr_batch)
            arr.flush()


class DataLoader:
    """Iterator that yields the result of ``func()`` a fixed number of times.

    Each call to ``iter()`` restarts the count, so the same loader can be
    used for several passes.

    Args:
        func: Zero-argument callable producing one batch per call.
        steps: Number of batches yielded per pass.
    """

    def __init__(self, func, steps):
        self.func = func
        self.steps = steps
        # Initialised here as well so that next() works before iter().
        self.step = 1

    def __iter__(self):
        # Restart the 1-based step counter for a new pass.
        self.step = 1
        return self

    def __next__(self):
        if self.step > self.steps:
            raise StopIteration
        self.step += 1
        return self.func()