import os
import re
from itertools import combinations
import pandas as pd
import csv

def get_sentence_data(fn):
    """
    Parse a Flickr30K caption file into a per-image list of sentences.

    Each non-blank line is expected to start with an ID token such as
    ``1000092795.jpg#0`` followed by the caption words. Flickr30K
    Entities-style phrase markup (``[/EN#18/people A child]``) is also
    accepted: the opening ``[...`` header token is dropped and the closing
    ``]`` is stripped, so only the plain words remain. Standalone ``.``
    tokens are removed.

    input:
      fn - full file path to the sentence file to parse

    output:
        a dictionary mapping each image ID (the ID token up to its first
        ``.``) to the list of that image's sentences, in file order. Lines
        without an ID token are collected under the key ``None``.
    """
    # Raw string: '\d' in a plain literal is an invalid escape sequence.
    id_pattern = re.compile(r'#\d$')

    full_annotations = {}
    with open(fn, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue

            words = []
            in_phrase = False
            dict_key = None
            for token in line.split():
                # Only the first ID-looking token is the image key; later
                # tokens such as "#3" (a jersey number) are ordinary words and
                # must not overwrite the key.
                if dict_key is None and id_pattern.search(token) is not None:
                    dict_key = token.split(".")[0]  # drop ".jpg#<n>"
                elif token == '.':  # do not include the punctuation
                    continue
                elif in_phrase:
                    if token[-1] == ']':
                        in_phrase = False
                        token = token[:-1]
                    words.append(token)
                elif token[0] == '[':
                    # Phrase header like "[/EN#18/people": markup, not a word.
                    in_phrase = True
                else:
                    words.append(token)

            full_annotations.setdefault(dict_key, []).append(' '.join(words))

    return full_annotations

def generate_training_data_combinations(full_annotations_dict, r=2):
    """
    Build a DataFrame of all unordered sentence combinations per image.

    For each image, every combination of ``r`` of its sentences (pairs by
    default) becomes one row. With 5 sentences per image and ``r=2`` that is
    C(5, 2) = 10 rows per image. Rows are grouped by image in the dict's
    iteration order and, within an image, follow ``itertools.combinations``
    order. Images with fewer than ``r`` sentences contribute no rows.

    input:
      full_annotations_dict - dict mapping image ID -> list of sentences
      r - number of sentences per combination (default 2)

    output:
      a pandas DataFrame with one row per combination and integer column
      labels 0..r-1 (no columns at all if there are no combinations)
    """
    rows = [
        combo
        for sentences in full_annotations_dict.values()
        for combo in combinations(sentences, r)
    ]
    return pd.DataFrame(rows)


def _sentences_in_index_order(annotations, keys, n_sentences=5):
    """Return a Series of all sentences of the complete images among ``keys``.

    Only images with exactly ``n_sentences`` sentences are included, so a
    short image cannot raise IndexError. The ordering is: all first
    sentences, then all second sentences, etc., with images in ``keys`` order.
    """
    complete_keys = [k for k in keys if len(annotations[k]) == n_sentences]
    return pd.Series([annotations[k][i] for i in range(n_sentences) for k in complete_keys])


def _write_csv(data, path):
    """Write a Series/DataFrame as fully quoted CSV with no index and no header."""
    data.to_csv(path, index=False, quoting=csv.QUOTE_ALL, header=False)


def main(token_file="results_20130124.token", output_dir=".", seed=None):
    """
    Split the Flickr30K captions by image into train1/train2/test and write CSVs.

    Images keep their file order. The first 80% (rounded down) go to train1,
    the next 10% (rounded down) to train2, and the remainder to test.

    Files written to ``output_dir``:
      - DI-<split>-full.csv: all sentences of each split's images that have
        exactly 5 sentences, ordered by sentence index, then by image
      - DI-<split>-first-sentence.csv: the first sentence of every image
      - flickr-<split>.csv: all sentence pairs per image
      - flickr-train1+2-shuffled.csv: train1 and train2 pairs, shuffled

    input:
      token_file - path to the caption file parsed by get_sentence_data
      output_dir - directory the CSV files are written to
      seed - optional ``random_state`` for the shuffle (None = unseeded)
    """
    train1_percentage = 0.8
    train2_percentage = 0.1  # the test split receives the remainder

    print(os.getcwd())
    annotations = get_sentence_data(token_file)

    # define how big everything should be
    key_list = list(annotations.keys())
    n_train1 = int(train1_percentage * len(key_list))
    n_train2 = int(train2_percentage * len(key_list))

    train1_keys = key_list[:n_train1]
    train2_keys = key_list[n_train1:n_train1 + n_train2]
    test_keys = key_list[n_train1 + n_train2:]

    # split the data in three dicts
    train1_dict = {k: annotations[k] for k in train1_keys}
    train2_dict = {k: annotations[k] for k in train2_keys}
    test_dict = {k: annotations[k] for k in test_keys}

    # get all combinations for training
    train1_data = generate_training_data_combinations(train1_dict)
    train2_data = generate_training_data_combinations(train2_dict)
    test_data = generate_training_data_combinations(test_dict)
    # complete train data, shuffled for training the NNs later
    train_data = pd.concat([train1_data, train2_data]).sample(frac=1., random_state=seed)

    outputs = {
        # all sentences (incomplete images are skipped in every split)
        'DI-train1-full.csv': _sentences_in_index_order(train1_dict, train1_keys),
        'DI-train2-full.csv': _sentences_in_index_order(train2_dict, train2_keys),
        'DI-test-full.csv': _sentences_in_index_order(test_dict, test_keys),
        # always write the first sentences for dataset inference
        'DI-train1-first-sentence.csv': pd.Series([train1_dict[k][0] for k in train1_keys]),
        'DI-train2-first-sentence.csv': pd.Series([train2_dict[k][0] for k in train2_keys]),
        'DI-test-first-sentence.csv': pd.Series([test_dict[k][0] for k in test_keys]),
        # sentence-pair data
        'flickr-train1.csv': train1_data,
        'flickr-train2.csv': train2_data,
        'flickr-test.csv': test_data,
        'flickr-train1+2-shuffled.csv': train_data,
    }
    for filename, data in outputs.items():
        _write_csv(data, os.path.join(output_dir, filename))


if __name__ == '__main__':
    main()