import csv
import os
import pandas as pd
import numpy as np


cwd = os.getcwd()
TRAIN_SPLIT_NAME = 'mnli_train-train.csv'
TEST_SPLIT_NAME = 'mnli_train-test.csv'
fraction_train_1 = 0.8

# Both input files are parsed identically: comma-separated, no header row.
_READ_OPTIONS = dict(sep=',', quoting=csv.QUOTE_ALL, header=None)
# Every output file is written fully quoted, without index and without header.
_WRITE_OPTIONS = dict(index=False, quoting=csv.QUOTE_ALL, header=False)


def _dedupe_and_report(data):
    """Drop duplicates from *data* in place and print its size before and after.

    Works for both DataFrames (duplicate rows) and Series (duplicate values).
    Returns the number of entries that remain.
    """
    print("Initial size: ", data.shape[0])
    data.drop_duplicates(inplace=True)
    remaining = data.shape[0]
    print("Size after dropping duplicates: ", remaining)
    return remaining


def _flatten(frame):
    """Stack all columns of *frame* on top of each other into one Series."""
    return pd.concat([frame[label] for label in frame.columns])


def _shared_values(left, right):
    """Return the values that occur in both *left* and *right* as a Series."""
    return pd.Series(list(set(left) & set(right)))


def _without(values, banned):
    """Return the entries of *values* that do not appear in *banned*."""
    is_banned = values.isin(banned)
    return values[~is_banned]


train_train = pd.read_csv(TRAIN_SPLIT_NAME, **_READ_OPTIONS)
train_test = pd.read_csv(TEST_SPLIT_NAME, **_READ_OPTIONS)

# Step 1: remove duplicate rows from each full dataframe (train first, then test).
total_size = _dedupe_and_report(train_train)
total_size = _dedupe_and_report(train_test)

# Step 2: cut the de-duplicated train frame into a leading and a trailing part.
total_size = train_train.shape[0]
num_train1 = int(fraction_train_1 * total_size)
num_train2 = total_size - num_train1

# NOTE(review): the -1 / +1 moves one row from the first part to the second;
# presumably intentional -- confirm.
train_train1 = train_train.head(num_train1 - 1)
train_train2 = train_train.tail(num_train2 + 1)

# Step 3: flatten every frame to a single Series, so that a value is compared
# regardless of which column it was listed in.
train_train_flat = _flatten(train_train)
train_train_flat1 = _flatten(train_train1)
train_train_flat2 = _flatten(train_train2)
train_test_flat = _flatten(train_test)

# Step 4: any value present in both train and test is removed from all three sets.
intersect = _shared_values(train_train_flat, train_test_flat)
train_train_clean1 = _without(train_train_flat1, intersect)
train_train_clean2 = _without(train_train_flat2, intersect)
train_test_clean = _without(train_test_flat, intersect)

# Step 5: likewise, any value shared by the two train parts is removed from both.
intersect2 = _shared_values(train_train_clean1, train_train_clean2)
train_train_clean1 = _without(train_train_clean1, intersect2)
train_train_clean2 = _without(train_train_clean2, intersect2)

# Step 6: flattening can reintroduce duplicates (the same value may have been
# listed in more than one column), so de-duplicate each cleaned Series.
total_size = _dedupe_and_report(train_train_clean1)
total_size = _dedupe_and_report(train_train_clean2)
total_size = _dedupe_and_report(train_test_clean)


# Step 7: write the three cleaned Series to disk.
train_train_clean1.to_csv('DI-cleaned-1-mnli.csv', **_WRITE_OPTIONS)
train_train_clean2.to_csv('DI-cleaned-2-mnli.csv', **_WRITE_OPTIONS)
train_test_clean.to_csv('DI-cleaned-mnli.csv', **_WRITE_OPTIONS)