import csv
import os
import pandas as pd

# Input CSV for each supported dataset.
FLICKR_PATH = "../flickr30k/flickr-train1.csv"  # this file is not shuffled, so train1/train2 can easily be generated from it
QQP_PATH = "../qqp/qqp_train.csv"
NLI_PATH = "../nli/train_train_NLI.csv"

DATA = "flickr"  # one of: "flickr", "qqp", "nli"

NUM_TO_INSERT = 50000  # total number of rows kept in the subset
NUM_TRAIN1 = 40000  # leading rows of the subset that go to train 1
NUM_TRAIN2 = NUM_TO_INSERT - NUM_TRAIN1  # remaining rows of the subset go to train 2

# Fail fast on a split that cannot be honoured (e.g. a negative NUM_TRAIN2).
if not 0 <= NUM_TRAIN1 <= NUM_TO_INSERT:
    raise ValueError(
        f"NUM_TRAIN1 ({NUM_TRAIN1}) must be between 0 and NUM_TO_INSERT ({NUM_TO_INSERT})"
    )

base_header = None  # passed as read_csv(header=None): the input has no header row

# set the paths
if DATA == "nli":
    BASE_DATA_PATH = NLI_PATH
elif DATA == "flickr":
    BASE_DATA_PATH = FLICKR_PATH
elif DATA == "qqp":
    BASE_DATA_PATH = QQP_PATH
else:
    # Without this, BASE_DATA_PATH would be undefined and read_csv below would
    # fail with an unhelpful NameError.
    raise ValueError(f"Unsupported DATA value {DATA!r}; expected 'flickr', 'qqp' or 'nli'")

# load the data
base_data = pd.read_csv(
    BASE_DATA_PATH,
    sep=',', quoting=csv.QUOTE_ALL, header=base_header)

if DATA == "flickr":
    # An earlier version discarded the first 9 rows, on the assumption that one
    # group of 10 was incomplete. That turned out not to be the case, so every
    # row is used and the groups of 10 start at row 0.
    #
    # From each group of 10 consecutive rows, keep only position 0 (=sent1,2)
    # and position 9 (=sent4,5), to keep duplicates low.
    #
    # The comprehension yields indices in ascending order, so the two kept rows
    # of a group stay adjacent. The contiguous train1/train2 split below then
    # keeps a group on one side, provided the split point (NUM_TRAIN1) is even.
    n_rows = base_data.shape[0]  # named n_rows so the builtin `len` is not shadowed
    keep_indices = [i for i in range(n_rows) if i % 10 in (0, 9)]

    # Positional selection of the kept rows, truncated to the subset size.
    train_data = base_data.iloc[keep_indices[:NUM_TO_INSERT], :]

elif DATA == "qqp":
    train_data = base_data.iloc[:NUM_TO_INSERT, :]  # in QQP, we can just take the data as is.

elif DATA == "nli":
    # Previously this case fell through and crashed later with
    # "NameError: name 'train_data' is not defined".
    raise NotImplementedError(
        "No subsetting strategy is defined for DATA='nli'; "
        "add one before generating NLI subsets."
    )

else:
    raise ValueError(f"Unsupported DATA value {DATA!r}; expected 'flickr', 'qqp' or 'nli'")

# Split the subset into two contiguous, non-overlapping parts: the first
# NUM_TRAIN1 rows form train 1, and the rows up to NUM_TO_INSERT form train 2.
train_train1 = train_data.iloc[:NUM_TRAIN1, :]
train_train2 = train_data.iloc[NUM_TRAIN1:NUM_TO_INSERT, :]

# Flatten each part by stacking its columns, one after another, into a single
# Series. Then drop exact repeats within that part. Each part is de-duplicated
# on its own; repeats across the two parts are not removed.
train_train1_flat = pd.concat(
    [train_train1[column] for column in train_train1.columns]
).drop_duplicates()
train_train2_flat = pd.concat(
    [train_train2[column] for column in train_train2.columns]
).drop_duplicates()

# Write, in this order:
# - the full subset, train 1 and train 2 (row-wise, for training);
# - the two flattened, de-duplicated parts (for DI).
# Every file is written without index or header, with all fields quoted.
for _out_frame, _out_path in (
    (train_data, f"Subset-train_{DATA}.csv"),
    (train_train1, f"Subset-1-train_{DATA}.csv"),
    (train_train2, f"Subset-2-train_{DATA}.csv"),
    (train_train1_flat, f"DI-mixed-1-train_{DATA}.csv"),
    (train_train2_flat, f"DI-mixed-2-train_{DATA}.csv"),
):
    _out_frame.to_csv(_out_path, index=False, quoting=csv.QUOTE_ALL, header=False)
