from datasets import load_dataset
import os
import json


# Datasets to fetch with `datasets.load_dataset` in the loop below.
#
# Each entry is a list naming one dataset:
#   ["name"]            -> load_dataset("name")
#   ["name", "config"]  -> load_dataset("name", "config")
# Entries of any other length are rejected by that loop.
#
# Only the final `data_list` assignment is active; the commented-out ones are
# earlier selections kept for reference -- uncomment one to reuse it.

# Pair sentences

# SNLI
#data_list = [["snli","plain_text"]]
#data_list = [["snli","plain_text"],["anli","plain_text"],["amazon_polarity","amazon_polarity"]]
#data_list = [["snli","plain_text"],["anli","plain_text"],["yelp_polarity"]]
#data_list = [["snli","plain_text"],["anli","plain_text"],["recast","recast_puns"]]
#data_list = [["snli","plain_text"],["recast","recast_puns"]]
# NOTE(review): this is two entries (`movie_rationales` and a dataset literally
# named `default`); `[["movie_rationales","default"]]` was probably intended --
# confirm before uncommenting.
#data_list = [["movie_rationales"],["default"]]
#data_list = [["movie_rationales"]]
#data_list = [["tweet_eval","sentiment"]]
# NOTE(review): "tweeteval" differs from the "tweet_eval" id used just above and
# may not resolve -- confirm before uncommenting.
#data_list = [["tweeteval","sentiment"]]
data_list = [["anli"]]


# Datasets whose output directory is named after the dataset itself; every
# other entry is stored under its last element (the config name if given).
_NAME_KEYED_DATASETS = {"snli", "anli", "yelp_polarity"}

# Standard split name -> output file stem ("validation" is saved as "dev").
_STANDARD_SPLITS = {"train": "train", "validation": "dev", "test": "test"}

# ANLI is split into rounds r1-r3, each with its own train/dev/test split.
# These are saved under their own names, e.g. "train_r1.json".
_ANLI_SPLITS = [
    f"{part}_{rnd}" for rnd in ("r1", "r2", "r3") for part in ("train", "dev", "test")
]


def _output_dir(data_name):
    """Return the local directory that holds the exported splits of *data_name*.

    *data_name* is a ``[name]`` or ``[name, config]`` entry of ``data_list``.
    Underscores are replaced with hyphens (``yelp_polarity`` -> ``yelp-polarity``).
    The same name is used both for the "already exists" check and for writing,
    so the two cannot disagree.
    """
    key = data_name[0] if data_name[0] in _NAME_KEYED_DATASETS else data_name[-1]
    return key.replace("_", "-")


def _dump_split(split, path):
    """Write all rows of *split* (an iterable of JSON-serializable rows) to *path* as one JSON array."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(list(split), f)


if __name__ == "__main__":
    for data_name in data_list:
        # Fail loudly (non-zero exit) instead of silently exiting with status 0.
        if len(data_name) not in (1, 2):
            raise ValueError(
                f"data_list entries must be [name] or [name, config], got {data_name!r}"
            )

        out_dir = _output_dir(data_name)
        if os.path.isdir(out_dir):
            print("Already exist: ", out_dir)
            continue

        dataset = load_dataset(*data_name)
        print(dataset)
        print("=======================")

        os.mkdir(out_dir)
        if data_name[0] == "anli":
            # ANLI only exposes round-specific splits; all nine are expected.
            for split_name in _ANLI_SPLITS:
                _dump_split(
                    dataset[split_name],
                    os.path.join(out_dir, split_name + ".json"),
                )
        else:
            # Other datasets may lack a split (e.g. no test set): save what exists.
            for split_name, stem in _STANDARD_SPLITS.items():
                if split_name in dataset:
                    _dump_split(
                        dataset[split_name],
                        os.path.join(out_dir, stem + ".json"),
                    )



#SNLI
'''
DatasetDict({
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 10000
    })
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 550152
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 10000
    })
})
'''

#ANLI
'''
DatasetDict({
    train_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 16946
    })
    dev_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    test_r1: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    train_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 45460
    })
    dev_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    test_r2: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1000
    })
    train_r3: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 100459
    })
    dev_r3: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1200
    })
    test_r3: Dataset({
        features: ['uid', 'premise', 'hypothesis', 'label', 'reason'],
        num_rows: 1200
    })
})
'''

#movie_rationales
'''
DatasetDict({
    train: Dataset({
        features: ['review', 'label', 'evidences'],
        num_rows: 1600
    })
    validation: Dataset({
        features: ['review', 'label', 'evidences'],
        num_rows: 200
    })
    test: Dataset({
        features: ['review', 'label', 'evidences'],
        num_rows: 199
    })
})
'''
