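R"""Dev script: load the 'train' split via `math_dataset.load_true_false_only`.

The split is tokenized with the `prajjwal1/bert-small` tokenizer at sequence
length 128, and the script prints the first 128 examples.

Usage:

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/misc1/math_dataset_dev001.py


"""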
import collections
import dataclasses
import os
from importlib import reload
import itertools
import time
from typing import Any, List, Sequence

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import torch
from torchnmf.nmf import NMF as TorchNMF
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em.datasets import math_dataset
from em.util import vat_da_faak_vpn

###############################################################################

# Ask TensorFlow to allocate GPU memory on demand rather than reserving it all
# when the device is first initialized. This presumably leaves room for
# PyTorch, which is also imported here; confirm that is the intent.
# `set_memory_growth` must run before any GPU is initialized, so this stays at
# the top of the script.
_tf_experimental = tf.config.experimental
gpus = _tf_experimental.list_physical_devices('GPU')
for gpu in gpus:
    _tf_experimental.set_memory_growth(gpu, True)

# Disabled alternative; note `ndb` is not imported in this file:
# ndb.separate_tf_torch_gpus()

###############################################################################

# Hugging Face Hub checkpoint id; its tokenizer is loaded below.
PRETRAINED_MODEL: str = "prajjwal1/bert-small"

# Passed to the dataset loader as `sequence_length`. Presumably the
# pad/truncate length for tokenized examples; confirm in em.datasets.math_dataset.
SEQUENCE_LENGTH: int = 128

###############################################################################

# Tokenizer matching the pretrained checkpoint; it is handed to the dataset
# loader so examples are tokenized consistently with that model.
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

# Arguments for the project loader. What "true_false_only" filters is defined
# in em.datasets.math_dataset, which is not visible here.
_loader_kwargs = dict(
    split='train',
    tokenizer=tokenizer,
    sequence_length=SEQUENCE_LENGTH,
)
ds = math_dataset.load_true_false_only(**_loader_kwargs)

# Quick sanity check: print the first few elements of `ds`. Each element
# unpacks into two values, presumably (inputs, label); verify in the loader.
_NUM_PREVIEW = 128
for x, y in itertools.islice(ds, _NUM_PREVIEW):
    print(x, y)
