from random import seed
import pandas as pd
import nlpaug.augmenter.word as naw
from tqdm import tqdm

# Back-translate (en -> fr -> en) a random sample of The Pile and write each
# document's original text alongside its back-translated augmentation.
INPUT_PATH = 'data/samples/the_pile_sample_512.csv'
OUTPUT_PATH = 'data/samples/the_pile_sample_aug.csv'
SAMPLE_FRAC = 0.2
RANDOM_STATE = 42
# Corpora whose rows are dropped before augmentation.
EXCLUDED_CORPORA = ['europarl']

tqdm.pandas()
aug = naw.BackTranslationAug(
    from_model_name="Helsinki-NLP/opus-mt-en-fr",
    to_model_name="Helsinki-NLP/opus-mt-fr-en",
    device='cuda',
    batch_size=2,
    max_length=512,
)
df = pd.read_csv(INPUT_PATH).sample(frac=SAMPLE_FRAC, random_state=RANDOM_STATE)

# Keep only rows from non-excluded corpora that actually have text: missing
# cells come back from read_csv as float NaN, which is not a string and
# cannot be translated. .copy() makes the column assignment below write to an
# independent frame rather than a slice of the sampled one.
keep = ~df['corpus'].isin(EXCLUDED_CORPORA) & df['text'].notna()
df = df.loc[keep].copy()

texts = df['text'].to_list()
# Skip the model call entirely when nothing survived the filters.
augs = aug.augment(texts) if texts else []
# The augmenter is expected to return one output per input text; fail loudly
# with a clear message rather than a generic pandas length error.
if len(augs) != len(df):
    raise ValueError(
        f"augmenter returned {len(augs)} outputs for {len(df)} input texts"
    )
df['augs'] = augs
print(df.head(10))
df[['doc_id', 'corpus', 'text', 'augs']].to_csv(OUTPUT_PATH, index=False)
