import os
import sys

import pandas as pd
from datasets import load_dataset

def translation_pairs(frame, target_lang, source_lang="en"):
    """Split a ``translation`` column into separate target/source text columns.

    Args:
        frame: DataFrame with a ``translation`` column whose cells are mappings
            keyed by language code (e.g. ``{"de": "...", "en": "..."}``).
        target_lang: Key whose text goes into the ``target`` column.
        source_lang: Key whose text goes into the ``source`` column.

    Returns:
        A new DataFrame with exactly the columns ``target`` and ``source``,
        in that order, one row per input row.

    Raises:
        KeyError: if a row's mapping lacks ``target_lang`` or ``source_lang``.
    """
    translations = frame["translation"]
    return pd.DataFrame(
        {
            "target": [pair[target_lang] for pair in translations],
            "source": [pair[source_lang] for pair in translations],
        },
        columns=["target", "source"],
    )


def export_wmt14_pair(output_dir, target_lang, source_lang="en"):
    """Download WMT14 ``{target_lang}-{source_lang}`` and write train/test CSVs.

    Writes ``{target}_{source}_train.csv`` from the ``train`` split and
    ``{target}_{source}_test.csv`` from the ``validation`` split. The file
    named "test" therefore holds the validation split, not WMT14's ``test``
    split.
    """
    dataset = load_dataset("wmt14", f"{target_lang}-{source_lang}")
    for split, suffix in (("train", "train"), ("validation", "test")):
        # to_pandas() converts in bulk; building the frame from the Dataset
        # object directly would iterate every row in Python.
        pairs = translation_pairs(dataset[split].to_pandas(), target_lang, source_lang)
        out_path = os.path.join(output_dir, f"{target_lang}_{source_lang}_{suffix}.csv")
        pairs.to_csv(out_path, index=False)


def main(output_dir):
    """Export WMT14 de/fr-en, English Wikipedia and CNN/DailyMail under *output_dir*."""
    os.makedirs(output_dir, exist_ok=True)

    export_wmt14_pair(output_dir, "de")
    export_wmt14_pair(output_dir, "fr")

    # Dataset.to_csv writes in batches instead of materialising the whole
    # (very large) Wikipedia split as one pandas DataFrame first.
    # NOTE: the file is named wiki.txt but its content is CSV; the name is
    # kept as-is for whatever consumes it downstream.
    wiki = load_dataset("wikipedia", "20220301.en")
    wiki["train"].to_csv(os.path.join(output_dir, "wiki.txt"), index=False)

    # As with WMT14, the "test" CSV is the validation split.
    cnn = load_dataset("cnn_dailymail", "3.0.0")
    cnn["train"].to_csv(os.path.join(output_dir, "cnn_dailymail_train.csv"), index=False)
    cnn["validation"].to_csv(os.path.join(output_dir, "cnn_dailymail_test.csv"), index=False)


if __name__ == '__main__':
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} OUTPUT_DIR")
    main(sys.argv[1])