import os
import json
import codecs
from tqdm import tqdm


def normalize_responsiveness(dataset):
    """Rescale every annotation's ``responsiveness`` by the dataset-wide maximum.

    The annotations are modified in place: each ``responsiveness`` value is
    divided by the largest value found across all topics, so the largest
    score becomes ``1.0``.

    Args:
        dataset: Mapping whose values each hold an ``"annotations"`` list of
            dicts carrying a numeric ``"responsiveness"`` entry.

    Returns:
        The same ``dataset`` object that was passed in.
    """
    # The running maximum starts at 0, so it never drops below zero.
    max_resp = 0.
    for topic in dataset.values():
        for annot in topic["annotations"]:
            if annot["responsiveness"] > max_resp:
                max_resp = annot["responsiveness"]

    # No positive score (or no annotation at all): there is nothing to scale
    # by, and dividing would raise ZeroDivisionError, so leave the data as is.
    if max_resp == 0:
        return dataset

    scale = float(max_resp)
    for topic in dataset.values():
        for annot in topic["annotations"]:
            annot["responsiveness"] /= scale
    return dataset


def load_json(filename):
    """Read the UTF-8 encoded JSON file at ``filename`` and return its content.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        The Python object decoded from the file by ``json.load``.
    """
    with open(filename, "r", encoding="utf-8") as f:
        return json.load(f)


if __name__ == "__main__":
    # Turn the TAC 2009 responsiveness/pyramid file into one record per kept
    # annotation and write the records out as NUMBER_OF_CHUNKS JSON files.
    save_dir = "data/tac2009_wasserstein_new"
    os.makedirs(save_dir, exist_ok=True)
    NUMBER_OF_CHUNKS = 50
    fpath, fname = "data/tac2009", "tac.09.mds.gen.resp-pyr"
    tac_09_mds_gen_resp_pyr = normalize_responsiveness(load_json(os.path.join(fpath, fname)))

    # Maps "<topic key>_<index of the kept annotation within the topic>" to
    # the record built for that annotation.
    final_dic = {}
    for k, v in tqdm(tac_09_mds_gen_resp_pyr.items()):
        # Each reference's "text" is joined with spaces into a single string
        # (presumably "text" is a list of sentences -- confirm against the data).
        references = [" ".join(ref["text"]) for ref in v["references"]]
        num_refs = len(references)
        count = 0
        for annot in tqdm(v["annotations"], "Annotations"):
            # Annotations whose "text" holds at most one element are skipped.
            if len(annot["text"]) <= 1:
                continue
            final_dic["{}_{}".format(k, count)] = {
                "references_sentences": references,
                # The summary is repeated once per reference so both lists
                # have the same length.
                "generated_sentence": [" ".join(annot["text"])] * num_refs,
                "scores": {
                    "pyr_score": float(annot["pyr_score"]),
                    "responsiveness": float(annot["responsiveness"]),
                },
            }
            count += 1

    # Explicit check instead of `assert`, which is removed when Python runs
    # with -O: every output file must receive the same number of records.
    total = len(final_dic)
    if total % NUMBER_OF_CHUNKS != 0:
        raise ValueError(
            "{} records cannot be split evenly into {} chunks".format(total, NUMBER_OF_CHUNKS)
        )

    # Materialise the items once rather than once per chunk.
    entries = list(final_dic.items())
    chunk_size = total // NUMBER_OF_CHUNKS
    processed_json = "tac2009_formated_{}.json"
    for chunk_index in range(NUMBER_OF_CHUNKS):
        start = chunk_index * chunk_size
        chunk = dict(entries[start:start + chunk_size])
        out_path = os.path.join(save_dir, processed_json.format(chunk_index))
        with open(out_path, "w", encoding="utf-8") as file:
            json.dump(chunk, file)
