# this file generates a train / test split from the NLI train data (to avoid distribution shift)
import csv
import os
import pandas as pd
import numpy as np


# Target fraction of the de-duplicated rows used when sizing the training split.
TRAIN_FRACTION = 0.9
# Input CSV with the NLI training data, resolved relative to the working directory.
TRAIN_SPLIT_NAME = 'nli_for_simcse.csv'

# Read every field verbatim as text.  With pandas' default NA handling, cells
# whose whole content is e.g. "N/A", "NA", "null", "None" or "nan" are parsed
# as missing values and would be written back out as empty strings, silently
# altering rows of a text dataset.  dtype=str additionally prevents numeric
# type inference from reformatting number-like fields.
train = pd.read_csv(
    TRAIN_SPLIT_NAME,
    sep=',', quoting=csv.QUOTE_ALL, header=0,
    dtype=str, keep_default_na=False)

total_size = train.shape[0]
print("Initial size: ", total_size)
# Remove rows that are identical in every column; rebinding avoids the
# in-place mutation form.
train = train.drop_duplicates()
total_size = train.shape[0]
print("Size after dropping duplicates: ", total_size)

# The rows are not homogeneous along the file order: there is a distribution
# shift in the middle of the dataset (potentially because it is MNLI attached
# to SNLI), so the rows are put into random order before the split is cut.
# No random_state is passed, so the order is not pinned across runs.
train = train.sample(frac=1)

num_train1 = int(TRAIN_FRACTION * total_size)
num_train2 = total_size - num_train1

# NOTE(review): the training slice is one row shorter than num_train1 and the
# test slice one row longer than num_train2; the two requested lengths still
# sum to total_size.  Unclear whether the one-row shift is intentional.
train_train = train.iloc[:num_train1 - 1]
train_test = train.iloc[-(num_train2 + 1):]

# Both splits are written with the same CSV settings: every field quoted,
# no index column and no header row.
_csv_options = dict(index=False, quoting=csv.QUOTE_ALL, header=False)
train_train.to_csv('train_train_NLI.csv', **_csv_options)
train_test.to_csv('train_test_NLI.csv', **_csv_options)