import timeit
from pathlib import Path

import numpy as np
import tensorflow as tf
from transformers import GPT2TokenizerFast

from microsoft_nlp.paths import owt2_tokenized


class TokenNpzToRecords:
    """Repack tokenized ``.npz`` shards into fixed-size token files.

    Every array stored in every ``*.npz`` file under ``owt2_tokenized`` is
    treated as one document (a 1-D array of token ids). ``convert_all``
    shuffles the documents with a seeded permutation and joins them into a
    single stream, with ``eod_token`` after each document. It then writes that
    stream out as consecutive chunks of exactly ``tokens_per_file`` tokens.
    """

    # Accepted values of ``convert_all(kind=...)``; ``None`` writes both formats.
    _KINDS = ("tfrecord", "npy", None)

    def __init__(self, order_seed=0, tokens_per_file=1024 * 512, eod_token=50256):
        """
        Args:
            order_seed: Seed for the ``np.random.RandomState`` permutation that
                fixes the document order.
            tokens_per_file: Number of tokens written to each output file.
            eod_token: Token id appended after every document. It must fit in
                ``uint16`` and equal the GPT-2 tokenizer's ``<|endoftext|>`` id.

        Raises:
            ValueError: if ``eod_token`` does not fit in ``uint16`` or is not
                the tokenizer's ``<|endoftext|>`` id.
        """
        if not 0 <= eod_token <= np.iinfo(np.uint16).max:
            raise ValueError(f"eod_token {eod_token} does not fit in uint16")

        self.order_seed = order_seed
        self.tokens_per_file = tokens_per_file

        self.files = sorted(owt2_tokenized.glob("*.npz"))

        # https://huggingface.co/gpt2/resolve/main/vocab.json
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        self.eod_token = np.asarray(eod_token, dtype=np.uint16)
        expected_eod = self.vocab["<|endoftext|>"]
        if expected_eod != self.eod_token:
            raise ValueError(
                f"eod_token {eod_token} does not match the tokenizer's "
                f"<|endoftext|> id {expected_eod}"
            )

    @property
    def n_documents(self):
        """Total number of documents (arrays) across all ``.npz`` input files.

        Only each archive's index (``NpzFile.files``) is read; the arrays
        themselves are not loaded.
        """
        total = 0
        for file_path in self.files:
            with np.load(file_path) as file_data:
                total += len(file_data.files)
        return total

    @property
    def vocab(self):
        """The tokenizer's token-string -> id mapping."""
        return self.tokenizer.vocab

    def convert_all(self, target_dir, kind="tfrecord", write_remainder=False):
        """Shuffle all documents and write them out as fixed-size token files.

        Output files are named ``tokens0000``, ``tokens0001``, ... inside
        ``target_dir``, with a suffix chosen by ``kind``. Each file holds exactly
        ``tokens_per_file`` consecutive tokens of the shuffled, EOD-separated
        stream, so documents may straddle file boundaries.

        Args:
            target_dir: Output directory (``str`` or ``Path``). It is created if
                missing.
            kind: ``"tfrecord"``, ``"npy"``, or ``None`` to write both formats.
            write_remainder: If true, also write the final partial chunk, which
                has fewer than ``tokens_per_file`` tokens. By default that chunk
                is dropped.

        Raises:
            ValueError: if ``kind`` is unrecognized, ``tokens_per_file`` is not
                positive, or a document is not a 1-D array.
        """
        # Validate cheaply before the (potentially very long) load phase.
        if kind not in self._KINDS:
            raise ValueError(f"Unrecognized type {kind}")
        tokens_per_file = self.tokens_per_file
        if tokens_per_file <= 0:
            raise ValueError(f"tokens_per_file must be positive, got {tokens_per_file}")

        target_dir = Path(target_dir)
        target_dir.mkdir(parents=True, exist_ok=True)

        all_docs = self._load_documents()

        order = np.random.RandomState(seed=self.order_seed).permutation(len(all_docs))
        ordered_docs = (all_docs[doc_i] for doc_i in order)

        timer = timeit.default_timer()
        for file_i, chunk in enumerate(self._iter_chunks(ordered_docs)):
            # Only the final chunk can be short (see _iter_chunks).
            if len(chunk) < tokens_per_file and not write_remainder:
                print(f"Dropped {len(chunk)} trailing tokens that do not fill a file")
                break

            file_name = self._write_chunk(target_dir, file_i, chunk, kind)

            timer1 = timeit.default_timer()
            print(f"Wrote {file_name} ({timer1 - timer:0.1f} s)")
            timer = timer1

    def _load_documents(self):
        """Load every array of every input ``.npz`` file into a list, in file order."""
        all_docs = []
        timer = timeit.default_timer()
        for file_path in self.files:
            with np.load(file_path) as file_data:
                all_docs.extend(file_data.values())

            timer1 = timeit.default_timer()
            print(f"Loaded {file_path} ({timer1 - timer:0.1f} s)")
            timer = timer1
        return all_docs

    def _iter_chunks(self, docs):
        """Yield consecutive ``tokens_per_file``-sized chunks of the stream built
        from ``docs``, with ``eod_token`` after each document.

        After all full chunks, any leftover tokens (always fewer than
        ``tokens_per_file``) are yielded as one final short chunk.
        """
        size = self.tokens_per_file
        eod = np.atleast_1d(self.eod_token)
        pending = []  # arrays not yet emitted, in stream order
        pending_len = 0

        for doc in docs:
            if doc.ndim != 1:
                raise ValueError(f"Expected a 1-D token array, got shape {doc.shape}")
            pending.append(doc)
            pending.append(eod)
            pending_len += len(doc) + 1
            if pending_len < size:
                continue

            flat = np.concatenate(pending)
            assert flat.dtype == pending[0].dtype
            assert flat.ndim == 1

            # A long document can fill several files at once. Emit every full
            # chunk now so the carried-over buffer always stays below one file.
            start = 0
            while len(flat) - start >= size:
                yield flat[start : start + size]
                start += size

            # Copy the leftover so the large concatenated buffer can be freed.
            leftover = flat[start:].copy()
            pending = [leftover] if len(leftover) else []
            pending_len = len(leftover)

        if pending_len:
            yield np.concatenate(pending)

    def _write_chunk(self, target_dir, file_i, flat_doc, kind):
        """Write one chunk as ``tokens{file_i:04d}`` in the format(s) selected by ``kind``.

        Returns the path that is printed in the progress message. For
        ``kind=None`` that path has no suffix, because two files are written.
        """
        file_name = target_dir / f"tokens{file_i:04d}"
        if kind == "tfrecord":
            file_name = file_name.with_suffix(".tfrecord")
            self.write_tfrecord(file_name, flat_doc)
        elif kind == "npy":
            file_name = file_name.with_suffix(".npy")
            self.write_npy(file_name, flat_doc)
        elif kind is None:
            self.write_tfrecord(file_name.with_suffix(".tfrecord"), flat_doc)
            self.write_npy(file_name.with_suffix(".npy"), flat_doc)
        else:
            raise ValueError(f"Unrecognized type {kind}")
        return file_name

    def write_tfrecord(self, file_name, flat_doc):
        """Write ``flat_doc`` as a single ``tf.train.Example`` to a TFRecord file.

        The example has two features: ``n_tokens`` (int64 length) and
        ``tokens``, which holds the array's raw bytes from ``tobytes()``. The
        dtype is not recorded, so readers must decode with the writer's dtype
        (uint16 when the input shards are uint16).
        """
        assert flat_doc.ndim == 1
        feature = {
            "n_tokens": _int64_feature(len(flat_doc)),
            "tokens": _bytes_feature(flat_doc.tobytes()),
        }
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        with tf.io.TFRecordWriter(str(file_name)) as writer:
            writer.write(example.SerializeToString())

    def write_npy(self, file_name, flat_doc):
        """Save ``flat_doc`` to ``file_name`` with ``np.save``."""
        assert flat_doc.ndim == 1
        np.save(file_name, flat_doc)


def _bytes_feature(value):
    """Wrap one bytes/str value in a ``tf.train.Feature`` holding a BytesList.

    An eager tensor is first converted with ``.numpy()``, because BytesList
    cannot unpack a string stored inside an EagerTensor.
    """
    eager_tensor_cls = type(tf.constant(0))
    raw = value.numpy() if isinstance(value, eager_tensor_cls) else value
    bytes_list = tf.train.BytesList(value=[raw])
    return tf.train.Feature(bytes_list=bytes_list)


def _bytes_list_feature(values):
    """Wrap a sequence of bytes/str values in a ``tf.train.Feature`` BytesList."""
    bytes_list = tf.train.BytesList(value=values)
    return tf.train.Feature(bytes_list=bytes_list)


def _float_feature(value):
    """Wrap a single float/double in a ``tf.train.Feature`` with a one-element FloatList."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)


def _int64_feature(value):
    """Wrap a single bool/enum/int/uint in a ``tf.train.Feature`` with a one-element Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
