import csv
import os
import pandas as pd
import numpy as np




# Clean-up script: removes every value that occurs in both the train split and
# the test split (from both sides), drops remaining duplicates, and writes each
# cleaned split as a single-column, fully quoted CSV.

cwd = os.getcwd()
TRAIN_SPLIT_NAME = 'nli_for_simcse.csv'
TEST_SPLIT_NAME = 'all_nli_test.csv'
TRAIN_OUTPUT_NAME = 'DI-cleaned-nli-train.csv'
TEST_OUTPUT_NAME = 'DI-cleaned-nli-test.csv'


def flatten_columns(frame):
    """Stack all columns of *frame* into one Series, column after column.

    Always returns a Series, whatever the shape of *frame*: a 1x1 frame is
    not collapsed to a scalar and a multi-column frame is not left as a
    DataFrame (both of which ``DataFrame.squeeze()`` would do). Columns are
    selected by position, so duplicate column labels are handled as well.
    The original row labels are kept, so they repeat once per column.
    """
    column_count = frame.shape[1]
    if column_count == 0:
        # pd.concat raises on an empty list; an empty Series is the natural result.
        return pd.Series(dtype=object)
    return pd.concat([frame.iloc[:, pos] for pos in range(column_count)])


def remove_shared(train_flat, test_flat):
    """Drop from both Series every value that occurs in both of them.

    Returns a ``(train_clean, test_clean)`` tuple. Duplicates *within* each
    Series are left in place.
    """
    shared = pd.Series(list(set(train_flat).intersection(set(test_flat))))
    train_clean = train_flat[~train_flat.isin(shared)]
    test_clean = test_flat[~test_flat.isin(shared)]
    return train_clean, test_clean


def drop_duplicates_verbose(values):
    """Return *values* without duplicates, printing the size before and after."""
    print("Initial size: ", values.shape[0])
    deduplicated = values.drop_duplicates()
    print("Size after dropping duplicates: ", deduplicated.shape[0])
    return deduplicated


def main():
    """Read both splits, remove their overlap and duplicates, write the results."""
    # The train file has a header row; the test file does not.
    train_frame = pd.read_csv(
        TRAIN_SPLIT_NAME,
        sep=',', quoting=csv.QUOTE_ALL, header=0)
    test_frame = pd.read_csv(
        TEST_SPLIT_NAME,
        sep=',', quoting=csv.QUOTE_ALL, header=None)

    # Flatten both frames: the same value can appear in more than one column.
    train_flat = flatten_columns(train_frame)
    test_flat = flatten_columns(test_frame)

    train_clean, test_clean = remove_shared(train_flat, test_flat)

    # A value can still be repeated inside one split (e.g. listed in two
    # different columns), so de-duplicate each split on its own.
    train_clean = drop_duplicates_verbose(train_clean)
    test_clean = drop_duplicates_verbose(test_clean)

    train_clean.to_csv(TRAIN_OUTPUT_NAME, index=False, quoting=csv.QUOTE_ALL, header=False)
    test_clean.to_csv(TEST_OUTPUT_NAME, index=False, quoting=csv.QUOTE_ALL, header=False)


if __name__ == "__main__":
    main()