import json

import numpy as np
import tensorflow as tf
from transformers import GPT2TokenizerFast

from microsoft_nlp.paths import (
    owt2_tokenized,
    owt2_tokenized_nprecords,
    owt2_tokenized_tfrecords,
    owt_openai_samples,
)


class Dataset:
    """Empty base class shared by the dataset interfaces in this module.

    TODO: Stub for type inference. Refactor dup code from subclasses.
    """


class OWTSample(Dataset):
    """Interface for OpenWebText dataset sample released with GPT-2.

    https://github.com/openai/gpt-2-output-dataset/

    Documents are read from the JSONL files listed in ``self.splits`` (relative
    to ``owt_openai_samples``), re-tokenized with the GPT-2 tokenizer, and
    exposed as ``tf.data`` pipelines of ``(input, target)`` pairs.
    """

    def __init__(self, eod_token=50256):
        """
        Args:
            eod_token: Token ID appended after every document. Must equal the
                tokenizer's ``<|endoftext|>`` ID (checked below).
        """
        # split name -> (JSONL file name, number of documents in the split)
        self.splits = {
            "train": ("webtext.train.jsonl", 250000),
            "validation": ("webtext.valid.jsonl", 5000),
            "test": ("webtext.test.jsonl", 5000),
        }

        # https://huggingface.co/gpt2/resolve/main/vocab.json
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        self.eod_token = np.asarray(eod_token, dtype=np.uint16)
        # Sanity check: the configured end-of-document ID agrees with the tokenizer.
        assert self.vocab["<|endoftext|>"] == self.eod_token

    def __str__(self):
        lines = [
            f"{type(self).__name__}:",
            f" # Vocab: {len(self.vocab)}",
        ]
        lines.extend(
            f" # {split_name}: {split_size}"
            for split_name, (_, split_size) in self.splits.items()
        )
        return "\n".join(lines)

    @property
    def vocab(self):
        """Token-string to token-ID mapping of the GPT-2 tokenizer."""
        return self.tokenizer.vocab

    def get_splits(self, *args, **kwargs):
        """Return ``{split_name: dataset}`` for every split; args go to ``get_split``."""
        return {split: self.get_split(split, *args, **kwargs) for split in self.splits}

    def get_split(
        self,
        split,
        seq_len=1024,
        batch_size=512,
        shuffle_buffer=10000,
        repeat=None,
        pad_value=None,
        dtype=tf.uint16,
    ):
        """Build a ``tf.data.Dataset`` of ``(input, target)`` pairs for one split.

        Args:
            split: One of the keys of ``self.splits``.
            seq_len: Length of each input/target sequence. Windows of
                ``seq_len + 1`` tokens are split into a one-token shift.
            batch_size: Batch size; falsy disables batching. An incomplete final
                batch is dropped.
            shuffle_buffer: Shuffle buffer size in windows; falsy disables
                shuffling.
            repeat: Repeat indefinitely. ``None`` means "repeat only if the split
                name starts with 'train'".
            pad_value: If None, documents (each followed by ``eod_token``) are
                concatenated and chopped into ``seq_len + 1`` windows, dropping
                the remainder. Otherwise each document becomes exactly one window
                padded with this value; a document with more than ``seq_len``
                tokens raises an error while iterating.
            dtype: TF dtype of the emitted token tensors.

        Raises:
            IndexError: If ``split`` is not a known split.
        """
        if split not in self.splits:
            raise IndexError(
                f"unrecognized split; '{split}' not in: {list(self.splits.keys())}"
            )
        split_fname, _ = self.splits[split]
        # ``as_numpy_dtype`` is a property holding the numpy type (e.g. np.uint16);
        #  calling it would create a numpy scalar instead of passing the type.
        np_dtype = dtype.as_numpy_dtype

        def split_input_target(sequence):
            # Next-token prediction: targets are the inputs shifted by one.
            return sequence[:-1], sequence[1:]

        def gen():
            # ``with`` closes the file deterministically; JSONL is UTF-8, so do not
            #  depend on the locale's default encoding.
            with open(owt_openai_samples / split_fname, encoding="utf-8") as f:
                for i, line in enumerate(f):
                    if not line.strip():
                        continue  # tolerate blank lines (e.g. a trailing newline)
                    doc = json.loads(line)
                    tokens = self.tokenizer.encode(doc["text"])
                    # A lot of documents have length 1024 but only 1023 tokens, but
                    #  I don't see the pattern for when this happens or doesn't happen.
                    if len(tokens) != doc["length"] and not (
                        len(tokens) == 1023 and doc["length"] == 1024
                    ):
                        print(f"Skipping line {i} with inconsistent tokenization: {doc}")
                        continue

                    if pad_value is None:
                        yield np.concatenate(
                            (np.asarray(tokens, dtype=np_dtype), (self.eod_token,))
                        )
                        continue

                    # A padded window holds the tokens plus one eod_token.
                    if len(tokens) > seq_len:
                        raise ValueError(
                            f"line {i} has {len(tokens)} tokens, which does not fit "
                            f"in a padded window of seq_len={seq_len} plus eod_token"
                        )
                    out = np.full(
                        shape=seq_len + 1, fill_value=pad_value, dtype=np_dtype
                    )
                    out[: len(tokens)] = tokens
                    # Redundant when eod_token == pad_value.
                    out[len(tokens)] = self.eod_token
                    yield out

        dataset = tf.data.Dataset.from_generator(
            gen, output_signature=tf.TensorSpec(shape=(None,), dtype=dtype)
        )
        # If padding is used then the generator always outputs sequences of length
        #  seq_len + 1, otherwise they need to be reshaped as follows.
        if pad_value is None:
            dataset = dataset.unbatch().batch(seq_len + 1, drop_remainder=True)
        dataset = dataset.map(split_input_target)
        if shuffle_buffer:
            dataset = dataset.shuffle(shuffle_buffer)
        if batch_size:
            dataset = dataset.batch(batch_size, drop_remainder=True)

        repeat = split.startswith("train") if repeat is None else repeat
        if repeat:
            dataset = dataset.repeat()

        return dataset.prefetch(tf.data.experimental.AUTOTUNE)


class OWT2(Dataset):
    """Interface for OpenWebText2 TF datasets for language modelling.

    Documents are loaded lazily from the ``*.npz`` files in ``owt2_tokenized``,
    permuted at the document level with ``order_seed`` and then partitioned into
    the named splits by fraction.
    """

    n_files = 161

    def __init__(
        self,
        splits=None,
        order_seed=0,
        eod_token=50256,
    ):
        """
        Args:
            splits: Mapping of split name to fraction of documents, consumed in
                insertion order; fractions must be non-negative and sum to 1.
                Defaults to ``{"train": 0.96, "test": 0.03, "validation": 0.01}``.
            order_seed: Seed for the document permutation applied before
                splitting. With ``None`` numpy seeds from fresh entropy, so the
                split is not reproducible across runs.
            eod_token: Token ID appended after every document; must equal the
                tokenizer's ``<|endoftext|>`` ID (checked below).

        Raises:
            ValueError: If a fraction is negative or the fractions do not sum to 1.
            RuntimeError: If the number of ``*.npz`` files is not ``n_files``.
        """
        # None sentinel instead of a dict literal so instances never share a
        #  mutable default; copy so later caller mutations do not leak in.
        if splits is None:
            splits = {"train": 0.96, "test": 0.03, "validation": 0.01}
        self.splits = dict(splits)
        self.order_seed = order_seed

        # Validate with a tolerance (and before the slow indexing below): float
        #  sums such as 0.7 + 0.2 + 0.1 == 0.9999999999999999 are still valid.
        fractions = np.asarray(list(self.splits.values()), dtype=float)
        if np.any(fractions < 0):
            raise ValueError("split fractions must be non-negative")
        if not np.isclose(fractions.sum(), 1.0):
            raise ValueError("splits must sum to 1")

        self.files = sorted(owt2_tokenized.glob("*.npz"))
        if len(self.files) != self.n_files:
            raise RuntimeError(
                f"expected {self.n_files} files, but saw {len(self.files)}; see README"
            )

        # Build an index mapping document ID to NPZ handler and position so that
        #  documents can be loaded in lazily and shuffled in any order before splitting.
        self.document_to_callable = []
        for file_path in self.files:
            handler = np.load(file_path)
            for i in range(len(handler)):
                # Default arguments bind the current handler/i; a plain closure
                #  would see only the loop's final values.
                self.document_to_callable.append(
                    lambda handler=handler, i=i: handler[f"arr_{i}"]
                )

        # Shuffle the order of all the documents and then split by document index.
        self.order = np.random.RandomState(seed=order_seed).permutation(
            self.n_documents
        )
        # Intermediate cutoffs are truncated exactly as before (so a given seed
        #  yields the same splits); the last one is pinned to n_documents so float
        #  rounding cannot shave off a document.
        self.cutoffs = (np.cumsum(fractions) * self.n_documents).astype(int)
        self.cutoffs[-1] = self.n_documents
        self.order_splits = dict(
            zip(self.splits.keys(), np.split(self.order, self.cutoffs[:-1]))
        )

        # Check that all of the splits cover the list of all of the documents.
        assert np.all(np.concatenate(list(self.order_splits.values())) == self.order)

        # https://huggingface.co/gpt2/resolve/main/vocab.json
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        self.eod_token = np.asarray(eod_token, dtype=np.uint16)
        # Sanity check: the configured end-of-document ID agrees with the tokenizer.
        assert self.vocab["<|endoftext|>"] == self.eod_token

    def __str__(self):
        lines = [
            f"{type(self).__name__}:",
            f" # Vocab: {len(self.vocab)}",
            f" # Documents: {self.n_documents}",
        ]
        for split in self.splits:
            lines.append(f" - {split} - {len(self.order_splits[split])}")
        return "\n".join(lines)

    @property
    def n_documents(self):
        """Total number of documents across all NPZ files."""
        return len(self.document_to_callable)

    @property
    def vocab(self):
        """Token-string to token-ID mapping of the GPT-2 tokenizer."""
        return self.tokenizer.vocab

    def get_document(self, item, pad_value=None, seq_len=None):
        """Load one document by index and append ``eod_token``.

        Args:
            item: Index into the unshuffled document list (file order), in
                ``[0, n_documents - 1]``.
            pad_value: If given, right-pad the document (plus ``eod_token``)
                with this value up to the next multiple of ``seq_len``.
            seq_len: Block length used for padding; required with ``pad_value``.

        Returns:
            1-D array of the document's tokens followed by ``eod_token`` and,
            when padding, ``pad_value`` up to a multiple of ``seq_len``.

        Raises:
            IndexError: If ``item`` is out of range.
            ValueError: If ``pad_value`` is given without a positive ``seq_len``.
        """
        # Note: This is not thread safe! Due to np.load/zipfile implementation.
        # https://stackoverflow.com/questions/5624669/strange-badzipfile-bad-crc-32-problem
        if not (0 <= item <= self.n_documents - 1):
            raise IndexError(f"index {item} out of bounds: [0, {self.n_documents - 1}]")

        doc = self.document_to_callable[item]()
        if pad_value is None:
            return np.concatenate((doc, (self.eod_token,)))

        if seq_len is None or seq_len <= 0:
            raise ValueError("a positive seq_len is required when pad_value is given")
        doc_length = len(doc) + 1  # document length including eod_token
        n_blocks = -(-doc_length // seq_len)  # ceil of doc_length / seq_len
        full_length = n_blocks * seq_len

        out = np.full(shape=full_length, fill_value=pad_value, dtype=doc.dtype)
        out[: len(doc)] = doc
        out[len(doc)] = self.eod_token  # redundant when eod_token == pad_value
        return out

    def __getitem__(self, item):
        return self.get_document(item)

    def get_splits(self, *args, **kwargs):
        """Return ``{split_name: dataset}`` for every split; args go to ``get_split``."""
        return {split: self.get_split(split, *args, **kwargs) for split in self.splits}

    def get_split(
        self,
        split,
        seq_len=1024,
        batch_size=512,
        shuffle_buffer=10000,
        repeat=None,
        pad_value=None,
        dtype=tf.uint16,
    ):
        """Build a ``tf.data.Dataset`` of ``(input, target)`` pairs for one split.

        Args:
            split: One of the keys of ``self.splits``.
            seq_len: Length of each input/target sequence; windows of
                ``seq_len + 1`` tokens are split into a one-token shift.
            batch_size: Batch size; falsy disables batching (remainder dropped).
            shuffle_buffer: Shuffle buffer size in windows; falsy disables it.
            repeat: Repeat indefinitely. ``None`` means "repeat only if the split
                name starts with 'train'".
            pad_value: If given, every document is padded to a multiple of
                ``seq_len + 1`` so that no window spans two documents.
            dtype: TF dtype the inputs and targets are cast to.

        Raises:
            IndexError: If ``split`` is not a known split.
        """
        if split not in self.splits:
            raise IndexError(
                f"unrecognized split; '{split}' not in: {list(self.splits.keys())}"
            )
        order_split = self.order_splits[split]

        # TODO: This mostly duplicates the example in data.py.
        def split_input_target(sequence):
            # Next-token prediction: targets are the inputs shifted by one.
            return tf.cast(sequence[:-1], dtype), tf.cast(sequence[1:], dtype)

        def get_document(item):
            return self.get_document(item, pad_value=pad_value, seq_len=seq_len + 1)

        # TODO: Just preload everything into RAM and/or avoid use of tf.numpy_function.
        #  Can probably parallelize it then.
        # NOTE(review): Tout=np.uint16 assumes the NPZ arrays are stored as uint16.
        dataset = (
            tf.data.Dataset.from_tensor_slices(order_split)
            .map(
                lambda item: tf.numpy_function(get_document, inp=[item], Tout=np.uint16)
            )
            .unbatch()
            .batch(seq_len + 1, drop_remainder=True)
            .map(split_input_target)
        )
        if shuffle_buffer:
            dataset = dataset.shuffle(shuffle_buffer)
        if batch_size:
            dataset = dataset.batch(batch_size, drop_remainder=True)

        repeat = split.startswith("train") if repeat is None else repeat
        if repeat:
            dataset = dataset.repeat()

        return dataset.prefetch(tf.data.experimental.AUTOTUNE)


class OWT2Records(Dataset):
    """Interface for OpenWebText2 TF datasets for language modelling.

    Reads pre-tokenized record files (``*.npy`` for ``kind="np"``,
    ``*.tfrecord`` for ``kind="tf"``), each expected to hold
    ``tokens_per_file`` tokens. Splits are made at the file level. Only
    ``kind="np"`` is currently readable through ``get_split``.
    """

    n_files = 961
    tokens_per_file = 1024 * 512 * 16

    exts = {"np": ".npy", "tf": ".tfrecord"}
    dirs = {"np": owt2_tokenized_nprecords, "tf": owt2_tokenized_tfrecords}

    def __init__(
        self,
        splits=None,
        order_seed=0,
        eod_token=50256,
        kind="np",
    ):
        """
        Args:
            splits: Mapping of split name to fraction of files, consumed in
                insertion order; fractions must be non-negative and sum to 1.
                Defaults to ``{"train": 0.96, "test": 0.03, "validation": 0.01}``.
            order_seed: Seed for the file permutation; ``None`` keeps the sorted
                file order.
            eod_token: End-of-document token ID; must equal the tokenizer's
                ``<|endoftext|>`` ID (checked below).
            kind: Record format, a key of ``dirs`` (``"np"`` or ``"tf"``).

        Raises:
            ValueError: If ``kind`` is unknown, a fraction is negative, or the
                fractions do not sum to 1.
            RuntimeError: If the number of record files is not ``n_files``.
        """
        # None sentinel instead of a dict literal so instances never share a
        #  mutable default; copy so later caller mutations do not leak in.
        if splits is None:
            splits = {"train": 0.96, "test": 0.03, "validation": 0.01}
        self.splits = dict(splits)
        self.order_seed = order_seed

        # Explicit check (an assert would be stripped under ``python -O``).
        if kind not in self.dirs:
            raise ValueError(f"unrecognized kind; '{kind}' not in: {list(self.dirs)}")
        self.kind = kind
        self.ext = self.exts[kind]

        # Validate with a tolerance: float sums such as 0.7 + 0.2 + 0.1 evaluate
        #  to 0.9999999999999999 and used to be rejected.
        fractions = np.asarray(list(self.splits.values()), dtype=float)
        if np.any(fractions < 0):
            raise ValueError("split fractions must be non-negative")
        if not np.isclose(fractions.sum(), 1.0):
            raise ValueError("splits must sum to 1")

        self.files = sorted(self.dirs[kind].glob(f"*{self.ext}"))
        if len(self.files) != self.n_files:
            raise RuntimeError(
                f"expected {self.n_files} files, but saw {len(self.files)}; see README"
            )

        self.n_tokens = self.n_files * self.tokens_per_file

        # Documents are already shuffled before making record files, but shuffle
        #  the files too just to be safe.
        self.order = (
            np.arange(self.n_files)
            if order_seed is None
            else np.random.RandomState(seed=order_seed).permutation(self.n_files)
        )

        # Determine splits (currently based on number of files for fast access).
        #  Intermediate cutoffs are truncated exactly as before (so a given seed
        #  yields the same splits); the last one is pinned to n_files so float
        #  rounding cannot shave off a file.
        self.cutoffs = (np.cumsum(fractions) * self.n_files).astype(int)
        self.cutoffs[-1] = self.n_files
        self.order_splits = dict(
            zip(self.splits.keys(), np.split(self.order, self.cutoffs[:-1]))
        )

        # Check that all of the splits cover the list of all of the files.
        assert np.all(np.concatenate(list(self.order_splits.values())) == self.order)

        # https://huggingface.co/gpt2/resolve/main/vocab.json
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        self.eod_token = np.asarray(eod_token, dtype=np.uint16)
        # Sanity check: the configured end-of-document ID agrees with the tokenizer.
        assert self.vocab["<|endoftext|>"] == self.eod_token

    def __str__(self):
        lines = [
            f"{type(self).__name__}:",
            f" # Vocab: {len(self.vocab)}",
            f" # Tokens: {self.n_tokens}",
        ]
        # Per-split counts are numbers of record files, not documents.
        for split in self.splits:
            lines.append(f" - {split} - {len(self.order_splits[split])}")
        return "\n".join(lines)

    @property
    def vocab(self):
        """Token-string to token-ID mapping of the GPT-2 tokenizer."""
        return self.tokenizer.vocab

    def get_splits(self, *args, **kwargs):
        """Return ``{split_name: dataset}`` for every split; args go to ``get_split``."""
        return {split: self.get_split(split, *args, **kwargs) for split in self.splits}

    def get_split(self, split, **kwargs):
        """Build the ``tf.data`` pipeline for ``split``; kwargs go to the reader.

        Raises:
            IndexError: If ``split`` is not a known split.
            NotImplementedError: If ``self.kind`` has no reader (currently "tf").
        """
        if split not in self.splits:
            raise IndexError(
                f"unrecognized split; '{split}' not in: {list(self.splits.keys())}"
            )

        if self.kind == "np":
            return self._get_split_np(split, **kwargs)
        # elif self.kind == "tf":
        #     return self._get_split_tf(split, **kwargs)
        else:
            raise NotImplementedError(self.kind)

    def _get_split_np(
        self,
        split,
        seq_len=1024,
        batch_size=512,
        shuffle_buffer=10000,
        pad_value=None,
        dtype=tf.uint16,
        targets=True,
        cycle_length=None,
        block_length=None,
        repeat=None,
    ):
        """Build a pipeline over the ``.npy`` record files of one split.

        Each file is loaded whole and cut into windows of ``seq_len + 1``
        tokens (``seq_len`` when ``targets`` is False); the last partial window
        of every file is dropped.

        Args:
            split: Split name, already validated by ``get_split``.
            seq_len: Length of each input/target sequence.
            batch_size: Batch size; falsy disables batching (remainder dropped).
            shuffle_buffer: Shuffle buffer size in windows; falsy disables it.
            pad_value: Must be None; padding is not supported.
            dtype: TF dtype requested from ``tf.py_function`` for loaded arrays.
            targets: If True yield ``(input, target)`` pairs shifted by one token,
                otherwise yield plain windows.
            cycle_length: Forwarded to ``Dataset.interleave``.
            block_length: Forwarded to ``Dataset.interleave``.
            repeat: Repeat indefinitely. ``None`` means "repeat only if the split
                name starts with 'train'".

        Raises:
            NotImplementedError: If ``pad_value`` is not None.
        """
        # Explicit check (an assert would be stripped under ``python -O``).
        if pad_value is not None:
            raise NotImplementedError("OWT2Records does not yet support padding")

        def tf_open_npy(filename):
            (data,) = tf.py_function(
                lambda f: np.load(f.numpy().decode()),
                [filename],
                [dtype],
            )
            window_len = seq_len + (1 if targets else 0)
            return tf.data.Dataset.from_tensor_slices(data).batch(
                window_len, drop_remainder=True
            )

        def split_input_target(sequence):
            # Next-token prediction: targets are the inputs shifted by one.
            return sequence[:-1], sequence[1:]

        order_split = self.order_splits[split]
        files = [str(self.files[i]) for i in order_split]

        dataset = tf.data.Dataset.from_tensor_slices(files)
        dataset = dataset.interleave(
            tf_open_npy, cycle_length=cycle_length, block_length=block_length
        )
        if targets:
            dataset = dataset.map(split_input_target)
        if shuffle_buffer:
            dataset = dataset.shuffle(shuffle_buffer)
        if batch_size:
            dataset = dataset.batch(batch_size, drop_remainder=True)

        repeat = split.startswith("train") if repeat is None else repeat
        if repeat:
            dataset = dataset.repeat()

        return dataset.prefetch(tf.data.experimental.AUTOTUNE)

    # def _get_split_tf(
    #     self,
    #     split,
    #     seq_len=1024,
    #     batch_size=512,
    #     shuffle_buffer=10000,
    #     dtype=tf.uint16,
    #     targets=True,
    # ):
    #     def decode_example(raw_example):
    #         featuremap = {
    #             "n_tokens": tf.io.FixedLenFeature((), tf.int64),
    #             "tokens": tf.io.VarLenFeature(tf.string),
    #             # "tokens": tf.io.VarLenFeature(tf.string),
    #         }

    #         features = tf.io.parse_single_example(raw_example, featuremap)
    #         # data = tf.py_function(
    #         #     lambda s: np.frombuffer(s, dtype=np.uint16),
    #         #     # [features["tokens"]],
    #         #     [features["tokens"].values[0]],
    #         #     [tf.uint16],
    #         # )

    #         (data,) = tf.py_function(
    #             lambda s: np.frombuffer(s, dtype=np.uint16),
    #             # [features["tokens"]],
    #             [features["tokens"].values[0]],
    #             [tf.uint16],
    #         )

    #         return data

    #     # def split_input_target(sequence):
    #     #     return tf.cast(sequence[:-1], dtype), tf.cast(sequence[1:], dtype)

    #     order_split = self.order_splits[split]
    #     files = [str(self.files[i]) for i in order_split[:2]]

    #     dataset = tf.data.TFRecordDataset(files)
    #     dataset = dataset.map(decode_example)
    #     return dataset
    #     # dataset = dataset.interleave(tf_open_npy)
    #     # return (
    #     #     dataset.shuffle(shuffle_buffer)
    #     #     .batch(batch_size, drop_remainder=True)
    #     #     .prefetch(tf.data.experimental.AUTOTUNE)
    #     #     .repeat()
    #     # )
