from csv import reader
from torchtext.data.utils import get_tokenizer
import os

def next_token(data_path):
    """Lazily yield ``(label, tokens)`` pairs read from a CSV file.

    Every non-empty row is split into its first column (parsed as an
    integer label) and the remaining columns, which are joined with single
    spaces and passed to torchtext's ``"basic_english"`` tokenizer.

    Because this is a generator, the file is only opened once iteration
    starts and is closed when the generator is exhausted or closed.

    Args:
        data_path: Path of the CSV file to read.

    Yields:
        A tuple ``(label, tokens)`` where ``label`` is ``int(row[0]) - 1``
        and ``tokens`` is whatever the tokenizer returns for the joined
        text columns.

    Raises:
        ValueError: If the first column of a non-empty row cannot be
            parsed by ``int``.
    """
    tokenizer = get_tokenizer("basic_english")
    # The csv module requires newline='' so that line breaks embedded in
    # quoted fields reach the parser untouched.
    with open(data_path, 'r', newline='') as read_obj:
        for row in reader(read_obj):
            if not row:
                # csv.reader yields [] for blank lines (e.g. a trailing
                # empty line); indexing row[0] would raise IndexError.
                continue
            label, *text_fields = row
            # NOTE(review): the "- 1" presumably maps 1-based labels in the
            # file to 0-based class indices - confirm against the dataset.
            yield int(label) - 1, tokenizer(' '.join(text_fields))


class ClassificationDatasetSplits:
    """Locate the train/test CSV files of a dataset directory.

    Attributes are plain strings:
      root        directory passed to the constructor
      train_path  ``<root>/train.csv``
      test_path   ``<root>/test.csv``
    """

    def __init__(self, root='./data'):
        """Record ``root`` and derive both CSV paths from it.

        No file-system access happens here; the paths are only computed.
        """
        self.root = root
        self.train_path, self.test_path = (
            os.path.join(root, filename)
            for filename in ('train.csv', 'test.csv')
        )

    def splits(self):
        """Return ``(train, test)`` as a pair of ``next_token`` generators.

        Each element is the generator produced by ``next_token`` for the
        corresponding CSV path; nothing is read until it is iterated.
        """
        train_stream = next_token(self.train_path)
        test_stream = next_token(self.test_path)
        return train_stream, test_stream
