import pickle
import json
import os


def clean_summary(string):
    """Strip ``<t>``/``</t>`` markers and collapse repeated spaces.

    A single cleaning pass can expose more work: removing a tag may bring two
    spaces together, or splice a new tag out of the characters around it.
    The pass is therefore repeated until the text stops changing, instead of
    a fixed number of times.

    Args:
        string: Raw summary text, possibly containing ``<t>`` / ``</t>`` tags.

    Returns:
        The text with no ``<t>`` / ``</t>`` tags, no run of two or more
        consecutive spaces, and no leading or trailing whitespace.
    """
    # Every pass that changes the text also shortens it, so this terminates.
    while True:
        cleaned = string.replace('<t>', '').replace('</t>', '')
        cleaned = cleaned.replace("  ", " ").strip()
        if cleaned == string:
            return cleaned
        string = cleaned


def reformate_dict(input_dic):
    """Convert the raw summaries dictionary into the export layout.

    Args:
        input_dic: Mapping from an example id to a dict with the keys
            ``'ref_summ'`` (reference summary text) and
            ``'system_summaries'``. The latter maps a system name to a dict
            with ``'system_summary'`` (text) and ``'scores'`` (mapping of
            score name to a value accepted by ``float``).

    Returns:
        A new dict mapping each example id to::

            {'references_sentences': [<cleaned reference>],
             'system': {<system name>: {'generated_sentence': <cleaned text>,
                                        'scores': {<score name>: <float>}}}}

        ``input_dic`` is left unmodified; the returned scores are fresh dicts.
    """
    reformated_dict = {}
    for key_id, value in input_dic.items():
        entry = {
            # Single-element list: only one reference summary is available.
            'references_sentences': [clean_summary(value['ref_summ'])],
            'system': {},
        }
        for key_sys_name, value_sys in value['system_summaries'].items():
            entry['system'][key_sys_name] = {
                'generated_sentence': clean_summary(value_sys['system_summary']),
                # Build a new dict rather than converting in place, so the
                # caller's data is neither mutated nor shared with the output.
                'scores': {
                    score_name: float(score)
                    for score_name, score in value_sys['scores'].items()
                },
            }
        reformated_dict[key_id] = entry
    return reformated_dict


"""
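Layout of one entry of the dictionary returned by reformate_dict
(entries are keyed by example id; this script emits a single reference
sentence per entry, and every score is stored as a float):
{
  "references_sentences": [
    "i am an example of golden sentence",
    "i am an example of second sentence"
  ],
  "system": {
    "sysname_1": {
      "generated_sentence": "i am an example of generated sentence",
      "scores": {
        "rouge": 1.0,
        "bleu": 10.0
      }
    }
  }
}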
"""
def _load_pickle(path):
    """Return the object stored in the pickle file at *path*."""
    # NOTE(security): unpickling can execute arbitrary code; only load files
    # that come from a trusted source.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


def _write_chunks(formatted, save_path, filename_template, number_of_chunks):
    """Split *formatted* into equal consecutive chunks, one JSON file each.

    Args:
        formatted: Dictionary to split; insertion order defines the chunks.
        save_path: Existing directory that receives the JSON files.
        filename_template: File name containing one ``{}`` placeholder that
            is filled with the chunk index (0-based).
        number_of_chunks: Number of files to produce.

    Raises:
        ValueError: If the number of entries is not a multiple of
            ``number_of_chunks``. Raised before any file is written.
    """
    # Materialise the items once instead of once per chunk.
    items = list(formatted.items())
    # Explicit check rather than ``assert``, which is stripped under ``-O``.
    if len(items) % number_of_chunks != 0:
        raise ValueError(
            "cannot split {} entries into {} equal chunks".format(
                len(items), number_of_chunks))
    chunk_size = len(items) // number_of_chunks
    for chunk_index in range(number_of_chunks):
        start = chunk_index * chunk_size
        chunk = dict(items[start:start + chunk_size])
        target = os.path.join(save_path, filename_template.format(chunk_index))
        with open(target, "w") as file:
            json.dump(chunk, file)


if __name__ == '__main__':
    data_path = 'data/cnn_bis'
    save_path = 'data/cnn_bis_wasserstein_new'
    NUMBER_OF_CHUNKS = 10
    os.makedirs(save_path, exist_ok=True)

    # Abstractive summaries are exported first, then extractive ones; each
    # dataset is loaded, reformatted and written before the next is read.
    datasets = (
        ('abs.pkl', 'cnn_abs_formated_{}.json'),
        ('ext.pkl', 'cnn_ext_formated_{}.json'),
    )
    for pickle_name, filename_template in datasets:
        raw_summaries = _load_pickle(os.path.join(data_path, pickle_name))
        _write_chunks(reformate_dict(raw_summaries), save_path,
                      filename_template, NUMBER_OF_CHUNKS)
