import pathlib

import numpy as np
import tqdm
from transformers import GPT2TokenizerFast

from microsoft_nlp.owt2_archiver import Reader
from microsoft_nlp.paths import owt2_raw, owt2_tokenized


def _is_kept(metadata):
    """Return True if a document passes the OpenWebText2 quality filter.

    Each document's metadata lists every Reddit post that pointed to it
    (``reddit_ids``) together with those posts' karma (``reddit_scores``).
    A document is kept only if at least one of those posts has karma >= 3
    AND the document is English (``lang == "en"``).
    """
    all_karma = metadata["reddit_scores"]
    # Sanity check: one karma score per referencing Reddit post.
    assert len(metadata["reddit_ids"]) == len(all_karma)
    # The karma test is evaluated first so ``metadata["lang"]`` is only read
    # for documents that already passed it.
    return any(karma >= 3 for karma in all_karma) and metadata["lang"] == "en"


def main():
    """Tokenize the OpenWebText2 monthly archives with the GPT-2 tokenizer.

    For every ``*.jsonl.zst`` archive in ``owt2_raw`` (named ``YYYY-MM``)
    dated up to and including October 2018, encode each document accepted by
    ``_is_kept`` with the GPT-2 tokenizer and store the per-document
    ``uint16`` token arrays in ``owt2_tokenized/<YYYY-MM>.npz``. Documents
    are passed to ``np.savez`` positionally, so they are stored in order as
    ``arr_0``, ``arr_1``, ....

    Finally, print the totals and check them against the expected reference
    counts.
    """
    # https://huggingface.co/gpt2/resolve/main/vocab.json
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    # 50257 entries -> largest token id is 50256 < 2**16, so uint16 is lossless.
    assert len(tokenizer.vocab) == 50257
    files = sorted(owt2_raw.glob("*.jsonl.zst"))

    # np.savez does not create missing parent directories.
    pathlib.Path(owt2_tokenized).mkdir(parents=True, exist_ok=True)

    n_tokens = 0
    n_documents = 0
    for file_path in tqdm.tqdm(files, dynamic_ncols=True):
        # "YYYY-MM.jsonl.zst" -> with_suffix("") drops ".zst", .stem drops ".jsonl".
        name = pathlib.Path(file_path).with_suffix("").stem
        year, month = map(int, name.split("-"))
        assert 2005 <= year <= 2020
        assert 1 <= month <= 12

        # https://arxiv.org/pdf/2001.08361.pdf includes up to October 2018,
        #  so for consistency we filter out everything >= Nov 2018.
        if (year, month) >= (2018, 11):
            continue

        subdata = []
        reader = Reader()
        for document, metadata in reader.read_jsonl(file_path, get_meta=True):
            if not _is_kept(metadata):
                continue
            tokens = np.asarray(tokenizer.encode(document), dtype=np.uint16)
            subdata.append(tokens)
            n_tokens += len(tokens)
            n_documents += 1

        # np.savez appends ".npz" to the file name when it is missing.
        np.savez(owt2_tokenized / name, *subdata)

    print(f"Created {n_tokens} tokens from {n_documents} documents in total!")
    # Reference totals for this filter/tokenizer combination; a mismatch
    # presumably means the raw data, filters, or tokenizer changed.
    assert n_tokens == 8055703022
    assert n_documents == 9127994


# Run the tokenization only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
