import os
import numpy as np
import torch
import torch.utils.data as data
import transformers
import datasets
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import torchvision
import datasets

from collections import deque
from itertools import chain, islice
from pytorch_pretrained_bert import BertTokenizer
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIGenericIterator
from nvidia.dali.plugin.base_iterator import LastBatchPolicy
from tqdm import tqdm

def create_imagenet_dataloader(dataset_name, split, batch_size, data_dir, *, num_threads = 32, device_id = 0):
    """Build a DALI GPU iterator over an ImageNet-style directory tree.

    Images are read with ``fn.readers.file`` from ``<data_dir>/train`` when
    ``split == 'train'`` and from ``<data_dir>/val`` for any other split,
    decoded on the 'mixed' device, and emitted as 224x224 float tensors in
    NHWC layout, normalized with the ImageNet mean/std expressed on the
    0-255 scale.

    Train pipeline: shuffled reading, random-area / random-aspect crop during
    decode, resize to 256x256, 224x224 crop, random horizontal flip (p=0.5).
    Other splits: file-order reading, shorter side resized to 256, 224x224 crop.

    Args:
        dataset_name: Not used; kept for call-site compatibility.
        split: ``'train'`` selects the training pipeline; anything else selects
            the evaluation pipeline on the ``val`` directory.
        batch_size: Samples per batch.
        data_dir: Root directory containing ``train/`` and ``val/``.
        num_threads: CPU threads for the DALI pipeline (default 32).
        device_id: GPU index for the pipeline; the pipeline seed is
            ``12 + device_id``.

    Returns:
        A ``DALIGenericIterator`` with outputs ``'images'`` and ``'labels'``
        (labels moved to the GPU). It drops the last incomplete batch and
        resets automatically at the end of each epoch.
    """
    is_train = split == 'train'
    split_root = os.path.join(data_dir, 'train' if is_train else 'val')
    # ImageNet channel statistics, scaled to the 0-255 pixel range.
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    pipe = Pipeline(batch_size=batch_size, num_threads=num_threads, device_id=device_id, seed=12 + device_id)
    with pipe:
        # The iterator below locates this reader by its name 'Reader' to size the epoch.
        jpegs, labels = fn.readers.file(name='Reader', file_root=split_root, random_shuffle=is_train)
        if is_train:
            images = fn.decoders.image_random_crop(jpegs, device='mixed', output_type=types.RGB,
                                                   random_area=[0.08, 1.0], random_aspect_ratio=[0.75, 1.333])
            images = fn.resize(images, resize_x=256, resize_y=256)
            mirror = fn.random.coin_flip(probability=0.5)
            images = fn.crop_mirror_normalize(images, device='gpu', dtype=types.FLOAT,
                                              output_layout=types.NHWC, crop=(224, 224),
                                              mean=mean, std=std, mirror=mirror)
        else:
            images = fn.decoders.image(jpegs, device='mixed', output_type=types.RGB)
            images = fn.resize(images, resize_shorter=256)
            images = fn.crop_mirror_normalize(images, device='gpu', dtype=types.FLOAT,
                                              output_layout=types.NHWC, crop=(224, 224),
                                              mean=mean, std=std)
        pipe.set_outputs(images, labels.gpu())
    pipe.build()

    # NOTE(review): DROP also applies to the validation split, so the final
    # partial val batch is never evaluated -- confirm this is intended.
    return DALIGenericIterator(
        [pipe],
        output_map=['images', 'labels'],
        reader_name='Reader',
        last_batch_policy=LastBatchPolicy.DROP,
        auto_reset=True,
    )

def get_imagenet_dataloaders(batch_size = 256, data_dir = '../pytorch_imagenet_data/pytorch_imagenet_data'):
    """Return ``(train_loader, val_loader)`` DALI iterators for ImageNet.

    Args:
        batch_size: Batch size used by both loaders.
        data_dir: Root directory holding the ``train/`` and ``val/`` folders.
            Defaults to the path this function previously hard-coded.
    """
    train_loader = create_imagenet_dataloader('imagenet', 'train', batch_size, data_dir)
    val_loader = create_imagenet_dataloader('imagenet', 'val', batch_size, data_dir)
    return train_loader, val_loader

def cifar_dataset(batch_size = 256, *, root = './data', num_workers = 2):
    """Return ``(train_loader, test_loader)`` for CIFAR-10.

    The training set uses random 32x32 crops with 4-pixel padding and random
    horizontal flips; both splits are converted to tensors and normalized
    with per-channel CIFAR-10 mean/std. The dataset is downloaded into
    ``root`` if it is not already there.

    Args:
        batch_size: Batch size for both loaders. It was previously ignored in
            favour of hard-coded 128 (train) / 256 (test).
        root: Directory where torchvision stores / looks for CIFAR-10.
        num_workers: DataLoader worker processes per loader.
    """
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)
    transforms = torchvision.transforms
    normalize = transforms.Normalize(mean, std)

    train_tfms = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_tfms = transforms.Compose([transforms.ToTensor(), normalize])

    # CIFAR10 requires a root directory; it was missing before.
    train_ds = torchvision.datasets.CIFAR10(root, train=True, download=True, transform=train_tfms)
    test_ds = torchvision.datasets.CIFAR10(root, train=False, download=True, transform=test_tfms)

    train_loader = data.DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                                   num_workers=num_workers, pin_memory=True)
    test_loader = data.DataLoader(test_ds, batch_size=batch_size, shuffle=False,
                                  num_workers=num_workers, pin_memory=True)
    return train_loader, test_loader

class StreamingTextDataset(data.IterableDataset):
    """Pack a stream of text records into fixed-length token windows.

    Records are read from ``hf_streaming_ds`` (sharded per DataLoader worker
    when running inside one). The ``text_key`` field of each record is
    collected into batches of ``tokenize_batch_size`` texts and tokenized.
    The resulting ids are appended to a running buffer, optionally followed
    by ``eos_token_id`` after every document. After each tokenizer call,
    every full ``seq_len`` window in the buffer is cut off and yielded, so
    the stream is consumed lazily and memory stays bounded. Ids left at the
    end of the stream that do not fill a whole window are dropped.

    Each example is a dict with:

    - ``input_ids``: a ``torch.long`` tensor of length ``seq_len``;
    - ``attention_mask``: all ones, same shape;
    - ``labels``: a copy of ``input_ids``.
    """

    def __init__(self, hf_streaming_ds, tokenizer, *, text_key = 'text', seq_len = 1024, tokenize_batch_size = 2048, add_eos_between_docs = True, eos_token_id = None):
        """
        Args:
            hf_streaming_ds: Iterable of mapping records. Inside DataLoader
                workers it must also support ``shard(num_shards=..., index=...)``.
            tokenizer: Callable that tokenizes a list of strings and returns a
                mapping whose ``'input_ids'`` is a list of id lists. It also
                needs ``encode`` when no EOS id is available.
            text_key: Record field holding the text; non-strings are ``str()``-ed.
            seq_len: Window length in tokens; must be at least 1.
            tokenize_batch_size: Number of texts buffered per tokenizer call.
            add_eos_between_docs: Append ``eos_token_id`` after every document.
            eos_token_id: Separator id. The fallbacks are, in order:
                ``tokenizer.eos_token_id``, then the first id produced by
                encoding a newline.

        Raises:
            ValueError: If ``seq_len`` < 1 or no separator id can be derived.
        """
        super().__init__()
        if seq_len < 1:
            # A zero/negative window length would make the chunker loop forever.
            raise ValueError(f'seq_len must be >= 1, got {seq_len}')
        self.hf_ds = hf_streaming_ds
        self.tokenizer = tokenizer
        self.text_key = text_key
        self.seq_len = seq_len
        self.tokenize_batch_size = tokenize_batch_size
        self.add_eos_between_docs = add_eos_between_docs

        if eos_token_id is None:
            eos_token_id = getattr(tokenizer, 'eos_token_id', None)
        if eos_token_id is None:
            newline_ids = tokenizer.encode('\n', add_special_tokens=False)
            if not newline_ids:
                raise ValueError('tokenizer has no eos_token_id and encodes a newline to no tokens')
            eos_token_id = newline_ids[0]

        self.eos_token_id = int(eos_token_id)

    def _iter_shard(self):
        """Iterate the whole dataset, or this worker's shard when inside a DataLoader worker."""
        info = data.get_worker_info()
        if info is None:
            return iter(self.hf_ds)
        # NOTE(review): HF streaming datasets may already split work across
        # DataLoader workers by themselves; confirm this explicit shard does
        # not double-partition the stream.
        return iter(self.hf_ds.shard(num_shards = info.num_workers, index = info.id))

    def _yield_chunks_from_buffer(self, buf):
        """Pop and yield consecutive ``seq_len``-long id lists while the deque holds enough ids."""
        while len(buf) >= self.seq_len:
            yield [buf.popleft() for _ in range(self.seq_len)]

    def _consume_and_chunk(self, texts, token_buffer):
        """Tokenize ``texts`` and append their ids (plus optional separators) to ``token_buffer``."""
        enc = self.tokenizer(texts, add_special_tokens = False, return_attention_mask = False, return_token_type_ids = False, truncation = False)
        for ids in enc['input_ids']:
            token_buffer.extend(ids)
            if self.add_eos_between_docs:
                token_buffer.append(self.eos_token_id)

    def _make_example(self, chunk_ids):
        """Wrap one window of ids as a causal-LM example dict."""
        input_ids = torch.tensor(chunk_ids, dtype = torch.long)
        return {
            'input_ids': input_ids,
            'attention_mask': torch.ones_like(input_ids),
            'labels': input_ids.clone(),
        }

    def __iter__(self):
        token_buffer = deque()
        pending_texts = []

        for ex in self._iter_shard():
            txt = ex[self.text_key]
            pending_texts.append(txt if isinstance(txt, str) else str(txt))

            # Flush inside the loop; checking after the loop would buffer the entire stream.
            if len(pending_texts) >= self.tokenize_batch_size:
                self._consume_and_chunk(pending_texts, token_buffer)
                pending_texts.clear()
                for chunk in self._yield_chunks_from_buffer(token_buffer):
                    yield self._make_example(chunk)

        if pending_texts:
            self._consume_and_chunk(pending_texts, token_buffer)
            pending_texts.clear()

        for chunk in self._yield_chunks_from_buffer(token_buffer):
            yield self._make_example(chunk)

class SizedLoader:
    """Cap an iterable at a fixed number of items per pass and report that cap as its length.

    Each ``iter()`` call takes at most ``steps_per_epoch`` items from a fresh
    iteration of the wrapped ``loader``; ``len()`` always returns the cap,
    even if the wrapped iterable is shorter.
    """

    def __init__(self, loader, steps_per_epoch):
        self.loader = loader
        self.steps = int(steps_per_epoch)  # coerced so float/str step counts work with islice/len

    def __len__(self):
        return self.steps

    def __iter__(self):
        capped = islice(self.loader, self.steps)
        return iter(capped)

def _collate_lm_features(features):
    """Stack per-example ``input_ids``/``attention_mask``/``labels`` tensors along a new dim 0.

    Defined at module level rather than as a closure so that DataLoader
    workers started with the ``spawn`` method can pickle it.
    """
    return {key: torch.stack([f[key] for f in features], dim=0)
            for key in ('input_ids', 'attention_mask', 'labels')}


def build_text_dataloader(split, tokenizer, *, seq_len = 1024, batch_size = 64, num_workers = 16, shuffle = True, shuffle_buffer = 10_000, text_key = 'text', 
                            tokenize_batch_size = 2048, add_eos_between_docs = True, drop_last = True):
    """Return a DataLoader of packed ``seq_len`` token windows from streaming SlimPajama-6B.

    Args:
        split: Dataset split name passed to ``datasets.load_dataset``.
        tokenizer: Tokenizer used by ``StreamingTextDataset``; its
            ``eos_token_id`` (if any) separates documents.
        seq_len: Tokens per example.
        batch_size: Examples per batch.
        num_workers: DataLoader worker processes. With 0, data is loaded in
            the main process and ``prefetch_factor`` must be ``None``.
        shuffle: Apply a streaming shuffle with ``shuffle_buffer`` records.
        shuffle_buffer: Buffer size for the streaming shuffle.
        text_key: Record field that holds the text.
        tokenize_batch_size: Texts tokenized per tokenizer call.
        add_eos_between_docs: Insert the EOS id after every document.
        drop_last: Drop the final incomplete batch.
    """
    ds = datasets.load_dataset('DKYoon/SlimPajama-6B', split = split, streaming = True)

    if shuffle:
        ds = ds.shuffle(buffer_size = shuffle_buffer)
    # Was 'eos_token_ids' (typo), which always passed None.
    iter_ds = StreamingTextDataset(ds, tokenizer, text_key = text_key, seq_len = seq_len, tokenize_batch_size = tokenize_batch_size, add_eos_between_docs = add_eos_between_docs,
                                    eos_token_id = getattr(tokenizer, 'eos_token_id', None))

    use_workers = num_workers > 0
    # DataLoader rejects any non-None prefetch_factor when num_workers == 0.
    return data.DataLoader(iter_ds, batch_size = batch_size, collate_fn = _collate_lm_features, num_workers = num_workers, pin_memory = True, drop_last = drop_last,
                            persistent_workers = use_workers, prefetch_factor = 2 if use_workers else None)

def sj_get_text_dataloaders(model_name, seq_len = 1024, batch_size = 64, num_workers = 16, tokenize_batch_size = 2048, shuffle_buffer = 10_000):
    """Return ``[train_loader, validation_loader]`` over SlimPajama-6B for ``model_name``.

    Loads the fast tokenizer for ``model_name`` and falls back to its EOS
    token as the pad token when none is defined. It then builds one streaming
    loader per split; only the training split is shuffled.
    """
    tok = transformers.AutoTokenizer.from_pretrained(model_name, use_fast = True)
    if tok.pad_token is None:
        # Reuse EOS so that any downstream padding has a token to use.
        tok.pad_token = tok.eos_token

    return [
        build_text_dataloader(
            split_name,
            tok,
            seq_len = seq_len,
            batch_size = batch_size,
            num_workers = num_workers,
            shuffle = split_name == 'train',
            tokenize_batch_size = tokenize_batch_size,
            shuffle_buffer = shuffle_buffer,
        )
        for split_name in ('train', 'validation')
    ]

def _format_alpaca(batch):
    """Render a batch of Alpaca records as single prompt-plus-response strings.

    Args:
        batch: Column-oriented mapping (as passed by ``Dataset.map(batched=True)``)
            with ``'instruction'`` and ``'output'`` lists and an optional
            ``'input'`` list.

    Returns:
        ``{'text': [...]}`` with one string per record. A record whose
        ``input`` is missing, ``None``, empty, or whitespace-only uses the
        no-input template. That template no longer claims an input is
        attached and no longer has a stray space before ``### Instruction``.
    """
    instructions = batch.get('instruction', [])
    outputs = batch.get('output', [])
    # A dataset without an 'input' column previously produced zero rows via zip().
    inputs = batch.get('input') or [''] * len(instructions)

    texts = []
    for instr, inp, output in zip(instructions, inputs, outputs):
        if inp and inp.strip():
            t = f'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instr}\n\n### Input:\n{inp}\n\n### Response:\n{output}'
        else:
            t = f'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instr}\n\n### Response:\n{output}'
        texts.append(t)
    return {'text': texts}

def _tokenize_then_group(ds, tokenizer, seq_len, num_proc = 8):
    """Tokenize ``ds['text']`` and pack the tokens into fixed-length blocks.

    Each text gets ``tokenizer.eos_token`` appended (or nothing, if the
    tokenizer has none) and is tokenized without special tokens. Within every
    batched ``map`` call, the token lists are concatenated and cut into
    consecutive ``seq_len`` windows. The tail of each batch that does not
    fill a window is dropped.

    Args:
        ds: Dataset with a ``'text'`` column.
        tokenizer: Tokenizer returning ``input_ids`` and ``attention_mask``.
        seq_len: Tokens per output row.
        num_proc: Worker processes for ``ds.map``.

    Returns:
        The grouped dataset, formatted as torch tensors with columns
        ``input_ids`` and ``attention_mask``.
    """
    eos = tokenizer.eos_token or ''

    def tok(batch):
        texts = [x + eos for x in batch['text']]
        return tokenizer(texts, add_special_tokens = False, return_attention_mask = True)

    tokenized = ds.map(tok, batched = True, remove_columns = ds.column_names, num_proc = num_proc)

    def group(batch):
        # chain.from_iterable flattens in linear time; sum(lists, []) re-copied
        # the growing list for every document (quadratic in batch size).
        input_ids = list(chain.from_iterable(batch['input_ids']))
        attn = list(chain.from_iterable(batch['attention_mask']))
        total = (len(input_ids) // seq_len) * seq_len
        starts = range(0, total, seq_len)
        return {'input_ids': [input_ids[i:i + seq_len] for i in starts],
                'attention_mask': [attn[i:i + seq_len] for i in starts]}

    chunked = tokenized.map(group, batched = True, remove_columns = tokenized.column_names, num_proc = num_proc)
    return chunked.with_format(type = 'torch', columns = ['input_ids', 'attention_mask'])

def _prepare_hf_dataset(model_name, seq_len = 1024, val_fraction = 0.02, num_proc = 8, *, seed = None):
    """Load, format, split, tokenize and pack ``yahma/alpaca-cleaned``.

    Args:
        model_name: Name or path passed to ``AutoTokenizer.from_pretrained``.
        seq_len: Tokens per packed row.
        val_fraction: Fraction of records held out as the ``'test'`` split.
        num_proc: Worker processes for the ``map`` calls.
        seed: Optional seed for the shuffle and the train/test split.
            ``None`` keeps the previous non-deterministic behaviour.

    Returns:
        ``(tokenizer, DatasetDict(train=..., test=...))``. The tokenizer uses
        EOS as pad when it has no pad token, and is set to pad on the left.
    """
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, use_fast = True)
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token
    # HF tokenizers read `padding_side`; the former `padding_size` was a silent no-op.
    tokenizer.padding_side = 'left'

    raw = datasets.load_dataset('yahma/alpaca-cleaned')['train']
    formatted = raw.map(_format_alpaca, batched = True, remove_columns = raw.column_names, num_proc = num_proc, load_from_cache_file=False)

    formatted = formatted.shuffle(seed = seed)
    ds_dict = formatted.train_test_split(test_size = val_fraction, seed = seed)
    train = _tokenize_then_group(ds_dict['train'], tokenizer, seq_len, num_proc)
    val = _tokenize_then_group(ds_dict['test'], tokenizer, seq_len, num_proc)
    return tokenizer, datasets.DatasetDict(train = train, test = val)

def get_alpaca_text_dataloaders(model, batch_size = 256, seq_len = 1024, num_workers = 16, *, val_fraction = 0.02, num_proc = 8):
    """Return ``(train_loader, val_loader)`` over packed Alpaca-cleaned token blocks.

    The training loader shuffles and drops the last incomplete batch; the
    validation loader keeps file order and every batch. Worker persistence
    and prefetching are enabled only when ``num_workers > 0``.
    """
    _tokenizer, splits = _prepare_hf_dataset(model_name = model, seq_len = seq_len, val_fraction = val_fraction, num_proc = num_proc)

    use_workers = num_workers > 0
    common = dict(
        batch_size = batch_size,
        num_workers = num_workers,
        pin_memory = True,
        persistent_workers = use_workers,
        prefetch_factor = 2 if use_workers else None,
    )
    train_loader = data.DataLoader(splits['train'], shuffle = True, drop_last = True, **common)
    val_loader = data.DataLoader(splits['test'], shuffle = False, drop_last = False, **common)
    return train_loader, val_loader

class TextDataset(data.Dataset):
    """Map-style dataset over pre-chunked token rows where each row is also its own target.

    ``dataset`` is indexed along its first axis; in this file it is built as
    a 2-D ``(num_rows, seq_len)`` tensor. ``num_words`` is stored for
    callers and is not read by this class.
    """

    def __init__(self, dataset, num_words):
        self.num_words = num_words
        self.dataset = dataset

    def __len__(self):
        n_rows = self.dataset.shape[0]
        return n_rows

    def __getitem__(self, idx):
        rows = self.dataset
        return dict(input_ids=rows[idx], target_ids=rows[idx])

def _load_or_build_wikitext_split(tokenizer, model, split, file_path, cached_file_path):
    """Return ``(token_tensor, num_words)`` for one WikiText-103 split, using a ``torch.save`` cache.

    A cache file is reused unless it records a different ``model`` than the
    one requested. Caches written before the ``'model'`` entry existed are
    trusted as-is. Otherwise the raw split is tokenized and the cache is
    (re)written with the model name.
    """
    cache_path = cached_file_path.format(split)
    if os.path.exists(cache_path):
        cached = torch.load(cache_path)
        # Token ids from another tokenizer would be silently wrong for this model.
        if cached.get('model', model) == model:
            return cached['dataset'], cached['num_words']

    with open(file_path.format(split), 'r', encoding = 'utf-8') as f:
        lines = f.readlines()
    num_words = sum(len(line.split()) for line in lines)
    per_line_ids = [
        tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line.strip(' ').replace('\n', '[SEP]').replace('<unk>', '[UNK]')))
        for line in tqdm(lines)
    ]
    dataset = torch.tensor([index for ids in per_line_ids for index in ids], dtype = torch.long)
    torch.save({'dataset': dataset, 'num_words': num_words, 'model': model}, cache_path)
    return dataset, num_words


def get_text_dataloaders(model, batch_size, num_workers = 16, seq_len = 1024):
    """Return WikiText-103 ``(train, valid, test)`` loaders plus the tokenizer vocab size.

    Each split's token stream is truncated to a multiple of ``seq_len`` and
    viewed as ``(num_rows, seq_len)`` rows; only the train loader shuffles.
    Tokenized splits are cached under ``cached_wikitext/``. The cache is
    rebuilt when it was produced for a different ``model``.
    """
    tokenizer = transformers.AutoTokenizer.from_pretrained(model, do_lower_case = False)
    file_path = 'wikitext-103/wiki.{}.tokens'
    cache_dir = 'cached_wikitext'
    os.makedirs(cache_dir, exist_ok = True)
    cached_file_path = os.path.join(cache_dir, 'wikitext_{}.pt')

    dataloaders = []
    for split in ['train', 'valid', 'test']:
        # str() keeps the cache payload to primitive types (model may be a path object).
        tokens, num_words = _load_or_build_wikitext_split(tokenizer, str(model), split, file_path, cached_file_path)
        # Drop the tail that cannot fill a full row.
        num_tokens = (tokens.size(0) // seq_len) * seq_len
        rows = tokens.narrow(0, 0, num_tokens).view(-1, seq_len)
        dataloaders.append(data.DataLoader(TextDataset(rows, num_words), batch_size = batch_size,
                                           num_workers = num_workers, shuffle = (split == 'train')))
    train_loader, valid_loader, test_loader = dataloaders
    return train_loader, valid_loader, test_loader, len(tokenizer.vocab)