import pickle
import json
import os
import pandas as pd


def clean_summary(string):
    """Remove ``<t>``/``</t>`` sentence tags, collapse runs of spaces, and strip.

    The cleanup pass is repeated until the text stops changing. This handles tags
    that only appear after an earlier removal (e.g. ``'<<t>t>x'`` -> ``'x'``) and
    arbitrarily long runs of spaces, without a fixed iteration cap.

    Args:
        string: Raw summary text.

    Returns:
        The cleaned text with no ``<t>``/``</t>`` tags, no double spaces, and no
        leading/trailing whitespace.
    """
    while True:
        cleaned = (
            string.replace('<t>', '')
            .replace('</t>', '')
            .replace('  ', ' ')
            .strip()
        )
        # Every operation above only removes characters, so the length strictly
        # decreases until a fixed point is reached; the loop always terminates.
        if cleaned == string:
            return cleaned
        string = cleaned


def reformate_dict(input_dic):
    """Convert a raw per-example summaries dict into the formatted layout.

    Args:
        input_dic: Mapping of example id -> dict with keys ``'ref_summ'``
            (reference summary text) and ``'system_summaries'`` (mapping of
            system name -> dict with ``'system_summary'`` text and a
            ``'scores'`` mapping of metric name -> value convertible by
            ``float``).

    Returns:
        A new dict keyed by the same ids. Each value holds:

        * ``'references_sentences'``: a one-element list with the cleaned
          reference.
        * ``'system'``: system name -> ``{'generated_sentence': str,
          'scores': {metric: float}}``.

    The input is left untouched. Scores are copied into a fresh dict instead of
    being converted in place, so the result does not alias the caller's data.
    """
    reformated_dict = {}
    for key_id, value in input_dic.items():
        entry = {
            # clean_summary removes the '<t>'/'</t>' sentence markers.
            'references_sentences': [clean_summary(value['ref_summ'])],
            'system': {},
        }
        for key_sys_name, value_sys in value['system_summaries'].items():
            entry['system'][key_sys_name] = {
                'generated_sentence': clean_summary(value_sys['system_summary']),
                'scores': {
                    metric: float(score)
                    for metric, score in value_sys['scores'].items()
                },
            }
        reformated_dict[key_id] = entry
    return reformated_dict


"""
{
  "references_sentences": [
    "i am an example of golden sentence",
    "i am an example of second sentence"
  ],
  "systems": {
    "sysname_1": {
      "generated_sentence": "i am an example of generated sentence",
      "metrics": {
        "rouge": 1,
        "bleu": 10
      }
    }
  }
}
"""
def _group_consecutive(keys, *columns):
    """Split parallel ``columns`` into groups of consecutive rows sharing a key.

    A new group starts whenever ``keys[i]`` differs from ``keys[i - 1]``. Every
    group must start at an index that is a multiple of 3. This is enforced by an
    assert, presumably because each row is repeated once per annotator, three
    times in a row. TODO confirm against the dataset documentation.

    Returns:
        One list per column. Element ``g`` of each list holds that column's
        values for group ``g``. The final group is included.
    """
    grouped = [[] for _ in columns]
    previous_key = object()  # sentinel: the first row always opens a group
    for index, key in enumerate(keys):
        if key != previous_key:
            assert index % 3 == 0
            for column_groups in grouped:
                column_groups.append([])
            previous_key = key
        for column_groups, column in zip(grouped, columns):
            column_groups[-1].append(column[index])
    return grouped


def _every_third(grouped_values):
    """Keep the first value of each run of three in every group."""
    result = []
    for values in grouped_values:
        assert len(values) % 3 == 0
        result.append(values[::3])
    return result


def _triplet_means(grouped_values):
    """Replace each run of three consecutive values in every group by its mean."""
    result = []
    for values in grouped_values:
        assert len(values) % 3 == 0
        result.append(
            [sum(values[i:i + 3]) / 3 for i in range(0, len(values), 3)])
    return result


def main(data_path='data/hotel/data2text_gen.csv', save_dir='data/hotel/',
         sub_dataset='SFHOT', system_name='LOLS'):
    """Build the chunked formatted JSON files for one data-to-text sub-dataset.

    Reads ``data_path`` as CSV. The columns used are ``dataset``, ``system``,
    ``sys_ref``, ``orig_ref``, ``informativeness``, ``naturalness`` and
    ``quality``. The function then:

    1. Keeps only the rows of ``sub_dataset`` produced by ``system_name``.
    2. Groups consecutive rows with the same generated sentence (``sys_ref``).
    3. Collapses every triplet of rows into one reference and averaged scores.
    4. Writes the result to ``{sub_dataset}_formated_{i}.json`` files in
       ``save_dir``.

    The output is split into 3 chunks when the number of entries is divisible
    by 3, and into 2 chunks otherwise.

    Args:
        data_path: Path to the input CSV.
        save_dir: Directory where the JSON chunks are written.
        sub_dataset: Value of the ``dataset`` column to keep. The original
            script chose among 'BAGEL', 'SFRES' and 'SFHOT'.
        system_name: Value of the ``system`` column to keep. It is also used
            as the system key in the output.

    Returns:
        The full dict that was written, keyed by group index.

    Raises:
        ValueError: If no rows match ``sub_dataset`` and ``system_name``.
    """
    df = pd.read_csv(data_path)
    df = df[(df['dataset'] == sub_dataset) & (df['system'] == system_name)]
    if df.empty:
        raise ValueError('no rows for dataset {!r} and system {!r} in {}'.format(
            sub_dataset, system_name, data_path))

    generated_g, references_g, informativeness_g, naturalness_g, quality_g = (
        _group_consecutive(
            df['sys_ref'].tolist(),
            df['sys_ref'].tolist(),
            df['orig_ref'].tolist(),
            df['informativeness'].tolist(),
            df['naturalness'].tolist(),
            df['quality'].tolist(),
        ))

    generated_g = _every_third(generated_g)
    references_g = _every_third(references_g)
    informativeness_g = _triplet_means(informativeness_g)
    naturalness_g = _triplet_means(naturalness_g)
    # Averaged from the quality ratings (the original averaged naturalness here).
    quality_g = _triplet_means(quality_g)

    final_dic = {}
    groups = zip(generated_g, references_g, informativeness_g, naturalness_g,
                 quality_g)
    for index_sentence, (generated, references, inf, nat, qual) in enumerate(groups):
        assert all(s == generated[0] for s in generated)
        final_dic[index_sentence] = {
            "references_sentences": references,
            "system": {
                system_name: {
                    # NOTE: a list with one (identical) copy per reference,
                    # kept as-is for compatibility with existing outputs.
                    "generated_sentence": generated,
                    "scores": {
                        'informativeness': sum(inf) / len(inf),
                        'naturalness': sum(nat) / len(nat),
                        'quality': sum(qual) / len(qual),
                    },
                },
            },
        }

    n_entries = len(final_dic)
    number_of_chunks = 3 if n_entries % 3 == 0 else 2
    assert n_entries % number_of_chunks == 0
    items = list(final_dic.items())
    processed_json = "{}_formated_{}.json"
    for chunk_index in range(number_of_chunks):
        start = chunk_index * n_entries // number_of_chunks
        stop = (chunk_index + 1) * n_entries // number_of_chunks
        out_path = os.path.join(save_dir, processed_json.format(sub_dataset, chunk_index))
        with open(out_path, "w") as file:
            json.dump(dict(items[start:stop]), file)
    return final_dic


if __name__ == '__main__':
    main()
