from tqdm import trange
import json
import os
import pathlib
from collections import defaultdict
import random
import argparse
import csv

# Directory holding one TSV per language code (e.g. en.tsv), read by build_mmmlu.
DATASET_ROOT = pathlib.Path(__file__).parent.parent.joinpath("data", "MMMLU")

def build_mmmlu(seed, languages, train_instance_num=5000, generate_instance_num=5000,
                use_false_examples=False, data_root=None, output_root=None):
    """Build cross-lingual chosen/rejected pairs from MMMLU TSV files and save them as JSON.

    Each ``<lang>.tsv`` under ``data_root`` is parsed with a tab-delimited csv
    reader. Rows whose first cell is ``"Prompt"`` (the header) and blank lines
    are skipped. Each remaining row is used as:

    * ``row[0]`` - prompt text containing a ``<mask>`` placeholder,
    * ``row[1]`` - the ground-truth answer,
    * ``row[2]`` - a Python-literal list of candidate answers.

    ``train_instance_num`` row ids are sampled once from the first language's
    rows, and every language is indexed with those same ids. The TSVs are
    therefore assumed to be row-aligned across languages (TODO confirm).

    The function makes ``generate_instance_num // train_instance_num`` passes
    over the sampled ids. For each id it:

    * picks a language pair (the two given languages in order, or a random
      pair when more than two are given);
    * samples two distinct candidate positions;
    * uses those same positions in both languages to build the "chosen" and
      "rejected" completions.

    Positions are random, so "chosen" is not necessarily the ground truth.
    Rows with fewer than two usable candidates are skipped, so fewer than
    ``generate_instance_num`` instances may be written.

    Example output element (texts are illustrative)::

        {
            "prompt_1": "The capital of France is",
            "chosen_1": " Paris.",
            "rejected_1": " Amsterdam.",
            "prompt_2": "法国的首都是",
            "chosen_2": "巴黎。",
            "rejected_2": "阿姆斯特丹。"
        }

    Args:
        seed: Passed to ``random.seed``; note this seeds the global RNG.
        languages: Language codes, each with a ``<code>.tsv`` file. At least
            two are needed to form a pair.
        train_instance_num: Number of distinct row ids to sample (> 0).
        generate_instance_num: Number of pair-building attempts. Must be a
            positive multiple of ``train_instance_num``.
        use_false_examples: If true, the ground-truth answer is removed from
            each candidate list before sampling, and the output folder gets a
            ``_false`` suffix.
        data_root: Directory with the TSV files. Defaults to ``DATASET_ROOT``.
        output_root: Directory in which the output folder is created. Defaults
            to the ``data`` directory two levels above this file.

    Returns:
        pathlib.Path of the written JSON file (a list of instance dicts).

    Raises:
        AssertionError: If the instance counts are invalid.
        ValueError: From ``random.sample`` if ``train_instance_num`` exceeds the
            number of rows, or from ``ast.literal_eval`` on a malformed
            candidate list.
    """
    import ast  # literal_eval parses the candidate lists without executing file contents

    assert train_instance_num > 0, "train_instance_num must be > 0"
    assert generate_instance_num % train_instance_num == 0 and generate_instance_num > 0, \
        "generate_instance_num must be a positive integer multiple of train_instance_num"

    data_root = pathlib.Path(DATASET_ROOT if data_root is None else data_root)
    if output_root is None:
        output_root = pathlib.Path(__file__).parent.parent / "data"
    output_root = pathlib.Path(output_root)

    # lang -> list of TSV rows (header rows and blank lines dropped).
    data = {}
    for lang in languages:
        # newline="" is what the csv module expects. An explicit utf-8 encoding keeps
        # non-Latin scripts decoding the same regardless of the platform locale.
        with open(data_root / f"{lang}.tsv", encoding="utf-8", newline="") as f:
            # Use the csv reader to avoid issues with '"' in Hebrew (HE) words.
            reader = csv.reader(f, delimiter="\t")
            data[lang] = [row for row in reader if row and row[0] != "Prompt"]

    random.seed(seed)
    train_ids = random.sample(range(len(data[languages[0]])), train_instance_num)

    # Number of passes over train_ids.
    k = generate_instance_num // train_instance_num

    instances = []
    for _ in trange(k):
        for train_id in train_ids:
            if len(languages) == 2:
                lang1, lang2 = languages
            else:
                lang1, lang2 = random.sample(languages, 2)

            row_1 = data[lang1][train_id]
            row_2 = data[lang2][train_id]
            # Prompt = text before "<mask>"; the text right after it is appended
            # to each candidate answer to form a completion.
            prompt_parts_1 = row_1[0].split("<mask>")
            prompt_parts_2 = row_2[0].split("<mask>")

            ground_truth_1 = row_1[1]
            ground_truth_2 = row_2[1]

            # NOTE(review): literal_eval replaces the former eval() so that the
            # TSV content is only parsed as a Python literal, never executed.
            answer_list_1 = ast.literal_eval(row_1[2])
            answer_list_2 = ast.literal_eval(row_2[2])

            if use_false_examples:
                # Drop the ground truth so both completions come from the remaining candidates.
                answer_list_1 = [ans for ans in answer_list_1 if ans != ground_truth_1]
                answer_list_2 = [ans for ans in answer_list_2 if ans != ground_truth_2]

            # The same two positions index both lists. Sample within the shorter
            # list so ground-truth removal that leaves the lists with different
            # lengths cannot index past the end of either.
            num_candidates = min(len(answer_list_1), len(answer_list_2))
            if num_candidates < 2:
                continue
            chosen_id, rejected_id = random.sample(range(num_candidates), 2)

            suffix_1 = prompt_parts_1[1]
            suffix_2 = prompt_parts_2[1]
            # No separator is inserted between prompt, answer and suffix; any
            # whitespace must already be present in the TSV text.
            instances.append({
                "prompt_1": prompt_parts_1[0],
                "chosen_1": answer_list_1[chosen_id] + suffix_1,
                "rejected_1": answer_list_1[rejected_id] + suffix_1,
                "prompt_2": prompt_parts_2[0],
                "chosen_2": answer_list_2[chosen_id] + suffix_2,
                "rejected_2": answer_list_2[rejected_id] + suffix_2,
            })

    folder = f"seed{seed}_sample{generate_instance_num}_mmmlu"
    if use_false_examples:
        folder += "_false"
    save_path = output_root / folder / f"{'-'.join(languages)}.json"
    save_path.parent.mkdir(parents=True, exist_ok=True)

    with open(save_path, "w", encoding='utf-8') as f:
        # ensure_ascii=True: non-ASCII characters are written as \uXXXX escapes.
        json.dump(instances, f, ensure_ascii=True)
    return save_path


# Example usage (hyphenated flag spellings and --languages are accepted too):
# python data/mmmlu.py --langs en fr --generate_instance_num 10000
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # The first long option of each argument fixes the attribute name on `args`
    # (e.g. --langs / --languages -> args.langs); the others are aliases.
    parser.add_argument('--seed', type=int, default=0, help='random seed for reproducibility')
    parser.add_argument('--langs', '--languages', nargs='+', default=['en', 'fr'], help='languages')
    parser.add_argument('--train_instance_num', '--train-instance-num', type=int, default=5000,
                        help='number of training instances')
    parser.add_argument('--generate_instance_num', '--generate-instance-num', type=int, default=5000,
                        help='number of instances')
    parser.add_argument('--use_false_examples', '--use-false-examples', action='store_true',
                        help='whether to generate false examples')

    args = parser.parse_args()
    # Every instance needs a pair of languages; report a usage error here
    # instead of failing later inside random.sample.
    if len(args.langs) < 2:
        parser.error('--langs needs at least two languages')

    build_mmmlu(args.seed, args.langs,
                train_instance_num=args.train_instance_num,
                generate_instance_num=args.generate_instance_num,
                use_false_examples=args.use_false_examples)
