import os
import json
import pandas as pd
from tqdm import tqdm
from collections import Counter

def select_balanced_pairs(rows, per_class_limit):
    """Pick the first ``per_class_limit`` pairs of each label (0 and 1).

    Args:
        rows: Iterable of ``(question1, question2, label)`` tuples.
        per_class_limit: Maximum number of pairs kept for each label.

    Returns:
        A dict mapping the row position (0-based, in iteration order) to
        ``{"q1": ..., "q2": ..., "duplicate": label, "scores": {}}``.
        Rows whose label is neither 0 nor 1 are ignored.
    """
    selected = {}
    kept_per_label = {0: 0, 1: 0}
    for index, (first, second, label) in enumerate(rows):
        if label in kept_per_label and kept_per_label[label] < per_class_limit:
            kept_per_label[label] += 1
            selected[index] = {
                "q1": first,
                "q2": second,
                "duplicate": label,
                "scores": {},
            }
        # Once both quotas are full nothing else can be selected, so there
        # is no point in scanning the rest of the input.
        if all(kept >= per_class_limit for kept in kept_per_label.values()):
            break
    return selected


def split_into_chunks(items, number_of_chunks):
    """Split ``items`` into ``number_of_chunks`` consecutive equal slices.

    Args:
        items: A sequence supporting ``len`` and slicing.
        number_of_chunks: How many slices to produce; must be positive.

    Returns:
        A list with exactly ``number_of_chunks`` slices of ``items``.

    Raises:
        ValueError: If ``number_of_chunks`` is not positive or does not
            divide ``len(items)`` evenly.
    """
    if number_of_chunks <= 0:
        raise ValueError("number_of_chunks must be positive")
    chunk_length, remainder = divmod(len(items), number_of_chunks)
    if remainder:
        raise ValueError(
            "{} items cannot be split evenly into {} chunks".format(
                len(items), number_of_chunks))
    return [
        items[position * chunk_length:(position + 1) * chunk_length]
        for position in range(number_of_chunks)
    ]


if __name__ == "__main__":
    data_path = "data/duplicated_questions/quora_duplicate_questions.tsv"
    save_dir = 'data/duplicated_questions/'
    processed_json = "duplicate_questions_formated_en_{}.json"
    NUMBER_OF_CHUNKS = 20
    PAIRS_PER_CLASS = 2500  # upper bound would be sum(duplicate)

    df = pd.read_csv(data_path, sep="\t")
    # .tolist() converts numpy scalars to plain Python values so that the
    # selected entries can be serialized by json.dump.
    q1 = df.question1.values.tolist()
    q2 = df.question2.values.tolist()
    duplicate = df.is_duplicate.values.tolist()
    print('Statistic Counters', Counter(duplicate))

    final_dic = select_balanced_pairs(
        tqdm(zip(q1, q2, duplicate), total=len(q1)), PAIRS_PER_CLASS)

    print('Chunk length {}'.format(len(final_dic) / NUMBER_OF_CHUNKS))
    # Materialize the items once; the chunk length is derived from the
    # actual selection so it always agrees with the value printed above.
    chunks = split_into_chunks(list(final_dic.items()), NUMBER_OF_CHUNKS)
    for chunk_index, chunk_items in enumerate(chunks):
        out_path = os.path.join(save_dir, processed_json.format(chunk_index))
        with open(out_path, "w") as file:
            json.dump(dict(chunk_items), file)
