"""
Convert the MMLU test sets to BMLAMA-format TSV files: English from
Global-MMLU and the other languages from MMMLU.
"""
from tqdm import trange
import json
import os
import pathlib
from collections import defaultdict
import random
import argparse
import csv

import datasets

# Hugging Face Hub dataset identifiers.
MMMLU = "openai/MMMLU"  # loaded once per locale config (e.g. "ZH_CN")
MMLU = "CohereLabs/Global-MMLU"  # loaded with its "en" config

# Prompt for one four-option question; filled in with str.format().
template = (
    "Question: {question}\n"
    "A. {option_a}\n"
    "B. {option_b}\n"
    "C. {option_c}\n"
    "D. {option_d}\n"
    "Answer:"
)

# MMMLU locale configs converted by default.
_MMMLU_LANGUAGES = (
    "AR_XY", "BN_BD", "DE_DE", "ES_LA", "FR_FR", "HI_IN", "ID_ID",
    "IT_IT", "JA_JP", "KO_KR", "PT_BR", "SW_KE", "YO_NG", "ZH_CN",
)

# Appended to every prompt; presumably marks the answer slot for the
# downstream BMLAMA loader (TODO confirm).
_PLACEHOLDER = "<mask>"

# Candidate answers written on every row. csv.writer stringifies non-string
# values with str(), so this cell holds the list's Python repr
# "[' A', ' B', ' C', ' D']".
_CANDIDATES = [" A", " B", " C", " D"]

# Column names of each source dataset, keyed by role.
_GLOBAL_MMLU_FIELDS = {
    "question": "question",
    "a": "option_a",
    "b": "option_b",
    "c": "option_c",
    "d": "option_d",
    "answer": "answer",
    "subject": "subject",
}
_MMMLU_FIELDS = {
    "question": "Question",
    "a": "A",
    "b": "B",
    "c": "C",
    "d": "D",
    "answer": "Answer",
    "subject": "Subject",
}


def _write_tsv(path, header, rows, fields):
    """Write one BMLAMA-format TSV file from dataset ``rows``.

    Args:
        path: Destination file path (overwritten).
        header: Column names written as the first row.
        rows: Iterable of mapping-like examples.
        fields: Maps the roles question/a/b/c/d/answer/subject to the
            column names used by ``rows``.
    """
    # Explicit UTF-8: the default locale encoding (e.g. cp1252 on Windows)
    # cannot encode most of the non-Latin scripts in these datasets.
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, delimiter="\t")
        writer.writerow(header)
        for example in rows:
            if example[fields["subject"]] == "abstract_algebra":
                continue
            prompt = template.format(
                question=example[fields["question"]].strip(),
                option_a=example[fields["a"]],
                option_b=example[fields["b"]],
                option_c=example[fields["c"]],
                option_d=example[fields["d"]],
            ) + _PLACEHOLDER
            # The leading space on the answer matches the candidates' format.
            writer.writerow([prompt, f" {example[fields['answer']]}", _CANDIDATES])


def convert_mmlu(output_dir="MMMLU/", languages=None):
    """Write the MMLU test sets to BMLAMA-format TSV files.

    The English test split is loaded from Global-MMLU (``MMLU``, config
    ``"en"``) and written to ``<output_dir>/en.tsv``. Each locale in
    ``languages`` is loaded from MMMLU and written to
    ``<output_dir>/<first two letters of the locale, lowercased>.tsv``.

    Each row holds three columns:

    * the prompt, which is ``template`` filled in and followed by ``<mask>``;
    * the gold answer, prefixed with a space;
    * the candidate answers.

    Questions from the ``abstract_algebra`` subject are skipped.

    Args:
        output_dir: Directory for the TSV files; created if missing.
        languages: MMMLU locale configs to convert. ``None`` selects the
            default 14 locales.
    """
    header = ["Prompt", "Ans", "Candidate Ans"]
    print(f"header: {header}")

    if languages is None:
        languages = _MMMLU_LANGUAGES

    mmlu_en = datasets.load_dataset(MMLU, "en", split="test")

    # exist_ok replaces the racy exists-then-create check.
    os.makedirs(output_dir, exist_ok=True)

    _write_tsv(os.path.join(output_dir, "en.tsv"), header, mmlu_en, _GLOBAL_MMLU_FIELDS)

    for lang in languages:
        mmmlu = datasets.load_dataset(MMMLU, lang, split="test")
        print(mmmlu)
        out_path = os.path.join(output_dir, f"{lang[:2].lower()}.tsv")
        _write_tsv(out_path, header, mmmlu, _MMMLU_FIELDS)


# Example usage:
# Convert mmlu to BMLAMA format tsv files
# python data/mmlu.py
if __name__ == "__main__":
    convert_mmlu()
