#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import json
import os
import re


class InputExample:
    def __init__(self, paragraph, qa_list, label):
        self.paragraph = paragraph
        self.qa_list = qa_list
        self.label = label


def get_examples(data_dir, set_type):
    """
    Extract paragraph and question-answer list from each json file
    """
    examples = []

    levels = ["middle", "high"]
    set_type_c = set_type.split("-")
    if len(set_type_c) == 2:
        levels = [set_type_c[1]]
        set_type = set_type_c[0]
    for level in levels:
        cur_dir = os.path.join(data_dir, set_type, level)
        for filename in os.listdir(cur_dir):
            cur_path = os.path.join(cur_dir, filename)
            with open(cur_path, "r") as f:
                cur_data = json.load(f)
                answers = cur_data["answers"]
                options = cur_data["options"]
                questions = cur_data["questions"]
                context = cur_data["article"].replace("\n", " ")
                context = re.sub(r"\s+", " ", context)
                for i in range(len(answers)):
                    label = ord(answers[i]) - ord("A")
                    qa_list = []
                    question = questions[i]
                    for j in range(4):
                        option = options[i][j]
                        if "_" in question:
                            qa_cat = question.replace("_", option)
                        else:
                            qa_cat = " ".join([question, option])
                        qa_cat = re.sub(r"\s+", " ", qa_cat)
                        qa_list.append(qa_cat)
                    examples.append(InputExample(context, qa_list, label))

    return examples


def main():
    """
    Helper script to extract paragraphs questions and answers from RACE datasets.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-dir",
        help="input directory for downloaded RACE dataset",
    )
    parser.add_argument(
        "--output-dir",
        help="output directory for extracted data",
    )
    args = parser.parse_args()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir, exist_ok=True)

    for set_type in ["train", "dev", "test-middle", "test-high"]:
        examples = get_examples(args.input_dir, set_type)
        qa_file_paths = [
            os.path.join(args.output_dir, set_type + ".input" + str(i + 1))
            for i in range(4)
        ]
        qa_files = [open(qa_file_path, "w") for qa_file_path in qa_file_paths]
        outf_context_path = os.path.join(args.output_dir, set_type + ".input0")
        outf_label_path = os.path.join(args.output_dir, set_type + ".label")
        outf_context = open(outf_context_path, "w")
        outf_label = open(outf_label_path, "w")
        for example in examples:
            outf_context.write(example.paragraph + "\n")
            for i in range(4):
                qa_files[i].write(example.qa_list[i] + "\n")
            outf_label.write(str(example.label) + "\n")

        for f in qa_files:
            f.close()
        outf_label.close()
        outf_context.close()


if __name__ == "__main__":
    main()
