import numpy as np
import pytest
import tqdm

from microsoft_nlp.owt2_archiver import Reader
from microsoft_nlp.owt2_loader import OWT2, OWT2Records
from microsoft_nlp.paths import owt2_raw


def test_openwebtext2_raw():
    """Check the document count and total text size of one raw archive.

    The test is skipped when the raw dataset directory does not exist. Only
    the first ``*.jsonl.zst`` archive yielded by the glob is read; its number
    of documents and summed document length are compared with fixed values.
    """
    if not owt2_raw.is_dir():
        # raw dataset is not strictly required
        pytest.skip(f"raw OpenWebText2 directory not found: {owt2_raw}")

    # Adapted from https://openwebtext2.readthedocs.io/en/latest/
    # NOTE(review): the glob is not sorted, so "first" is whichever archive the
    # filesystem lists first. The expected totals below presumably belong to
    # that archive on the reference machine -- confirm before adding a sort.
    files = owt2_raw.glob("*.jsonl.zst")

    document_count = 0
    total_text_size = 0
    for file_path in tqdm.tqdm(files, dynamic_ncols=True):
        reader = Reader()
        # Metadata is requested (as in the upstream example) but not used here.
        for document, _metadata in reader.read_jsonl(file_path, get_meta=True):
            document_count += 1
            total_text_size += len(document)

        # Stop after the first file.
        break

    assert document_count == 70993
    assert total_text_size == 207998097


def test_openwebtext2_loader():
    """Smoke-test the OWT2 loader: first document, corpus size, one train batch."""
    dataset = OWT2()

    # The first document must be a uint16 array whose last entry is the EOD token.
    first_doc = dataset[0]
    assert first_doc.dtype == np.uint16
    assert first_doc[-1] == dataset.eod_token
    assert dataset.n_documents == 9127994

    # Pull a single batch from the train split and check both arrays' shapes.
    train_split = dataset.get_split("train")
    batch_iter = train_split.as_numpy_iterator()
    batch_x, batch_y = next(batch_iter)
    assert batch_x.shape == batch_y.shape
    assert batch_y.shape == (512, 1024)


def test_openwebtext2_padding():
    """Check that padded train rows contain only pad values after the EOD token."""
    dataset = OWT2()
    pad_token = 2 ** 16 - 1  # 65535, the largest uint16 value
    n_batches = 11  # number of leading batches inspected

    batches = dataset.get_split("train", pad_value=pad_token).as_numpy_iterator()
    # zip() stops once range() is exhausted, so at most n_batches are pulled.
    for _, (X, Y) in zip(range(n_batches), batches):
        assert X.shape == Y.shape == (512, 1024)

        for row in X:
            eod_positions = np.flatnonzero(row == dataset.eod_token)
            assert eod_positions.size <= 1, "Found multiple EOD tokens when padding"
            if eod_positions.size:
                # When EOD is the last entry the slice is empty and the check
                # passes trivially.
                after_eod = row[eod_positions[0] + 1 :]
                assert (after_eod == pad_token).all()


def test_owt2_records_loader():
    """Check that OWT2Records yields the same leading batches as OWT2.

    Both loaders are read without a shuffle buffer and their first batches
    are compared element-wise. Only the batches covered by one record block
    are compared, since that is the range the settings below make
    deterministic.
    """
    owt2 = OWT2(order_seed=0)  # this is the seed that was used to shuffle records
    owt2_records = OWT2Records(order_seed=None)  # order_seed=None for no file shuffling

    gen = owt2.get_split("train", shuffle_buffer=None)
    # Setting cycle_length=1 and block_length to 4 * batch_size makes things
    # deterministic for the first 4 batches.
    gen_records = owt2_records.get_split(
        "train", shuffle_buffer=None, cycle_length=1, block_length=2048
    )

    n_deterministic_batches = 4
    for k, ((X1, Y1), (X2, Y2)) in enumerate(zip(gen, gen_records)):
        assert X1.shape == X2.shape
        assert Y1.shape == Y2.shape
        # The batches hold token ids, so compare exactly rather than with a
        # floating-point tolerance.
        assert np.array_equal(X1, X2)
        assert np.array_equal(Y1, Y2)
        # Stop once the deterministic batches have been compared.
        if k + 1 >= n_deterministic_batches:
            break
