import os
import json
import argparse
import importlib
from util.constants import SCHEMA_KEYS

def create_input_from_dataset(example: dict, COT=False):
    """Assemble the ``input`` string for one instance and store it on the instance.

    The string is the concatenation of: the stripped leading instruction, a
    newline, ``example["dataset_input"]``, the rendered answer options (only
    for MCQ instances) and ``example["final_suffix_task_instruction"]``.

    Args:
        example: One instance dict. It is modified in place: the key
            ``"input"`` is added or overwritten.
        COT: When truthy, the leading instruction is read from the key
            ``SCHEMA_KEYS.COT_INSTRUCTION``; otherwise it is read from
            ``"final_prefix_task_instruction"``.

    Returns:
        The same ``example`` dict that was passed in.
    """
    if example["task_type"] == "MCQ":
        labels = example["candidate_answer_label_space"]
        # One "<label>. <answer>\n" line per label; answers are looked up by
        # position so a too-short answer list still raises IndexError.
        option_lines = [
            f"{label}. {example['candidate_answer_set'][position]}\n"
            for position, label in enumerate(labels)
        ]
        options_block = "\n" + "".join(option_lines)
    else:
        # Non-MCQ instances carry no option list in the input.
        options_block = ""

    instruction_key = SCHEMA_KEYS.COT_INSTRUCTION if COT else "final_prefix_task_instruction"
    leading_instruction = example[instruction_key].strip()

    example["input"] = (
        leading_instruction
        + "\n"
        + example["dataset_input"]
        + options_block
        + example["final_suffix_task_instruction"]
    )
    return example

def _print_instances(instances, count=2):
    """Pretty-print up to ``count`` leading instances between separator lines.

    Fewer than ``count`` instances are printed when ``instances`` is shorter,
    so small datasets do not raise ``IndexError``.
    """
    for index in range(min(count, len(instances))):
        print("*" * 50)
        print(json.dumps(instances[index], indent=4))
        print("-" * 50)


def main():
    """Build one output file per (dataset, instruction) pair listed in a config.

    Command-line arguments:
        --config: Path to a JSON file. Its top-level keys are dataset
            identifiers of the form ``<name>`` or ``<name>_<category>`` (split
            on the first underscore) and each value is an iterable of
            instruction names.
        --output_path: Directory for the generated files; created if missing.
        --cot / --no-cot: Whether ``create_input_from_dataset`` is called with
            ``COT=True``. Defaults to ``False``.

    For every pair the function:
        1. imports ``data_transform.hf_schema_transformer.<name>`` and
           instantiates the attribute called ``<name>`` from it;
        2. when a category is present, keeps only instances whose
           ``"dataset"`` field equals the full dataset identifier;
        3. maps the function ``instructions.<instruction>.<instruction>`` over
           the instances, then maps ``create_input_from_dataset``;
        4. writes the result with ``to_json`` to
           ``<output_path>/<dataset>_<instruction>.jsonl``.

    Raises:
        ValueError: If filtering by category leaves no instances.
    """
    parser = argparse.ArgumentParser(description='Construct InstructBenchmark')
    parser.add_argument(
        '--config',
        help='Path to the JSON config mapping each dataset to its instruction names',
        required=True,
    )
    parser.add_argument(
        '--output_path',
        help='Directory to write the generated .jsonl files to',
        required=True,
    )
    # BooleanOptionalAction already yields a bool (--cot / --no-cot); passing
    # ``type`` to it is deprecated in newer Python versions, so it is omitted.
    parser.add_argument(
        '--cot',
        help='To use CoT Prompt or Not',
        default=False,
        action=argparse.BooleanOptionalAction,
    )

    args = parser.parse_args()

    with open(args.config, "r", errors="ignore", encoding="utf8") as reader:
        config = json.load(reader)

    from datasets import disable_caching
    disable_caching()

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.output_path, exist_ok=True)

    importlib.invalidate_caches()
    # for each dataset
    for each_dataset in config:
        # "<name>_<category>" -> (name, category); no underscore -> no category.
        # NOTE(review): a dataset name that itself contains "_" would be split
        # here as well -- confirm config keys never do that.
        try:
            dataset_name, category = each_dataset.split("_", 1)
        except ValueError:
            dataset_name = each_dataset
            category = None

        # for every instruction
        for each_instruction in config[each_dataset]:
            module = importlib.import_module(f"data_transform.hf_schema_transformer.{dataset_name}")
            class_def = getattr(module, dataset_name)

            # A fresh dataset object per instruction, because its
            # intermediate_representation is overwritten below.
            dataset = class_def()
            if category is not None:
                dataset.intermediate_representation = dataset.intermediate_representation.filter(
                    lambda instance: instance["dataset"] == each_dataset
                )
                # Explicit raise instead of assert: asserts vanish under -O.
                if len(dataset.intermediate_representation) == 0:
                    raise ValueError(
                        f"No instances left for dataset '{each_dataset}' after filtering by category"
                    )

            print("Printing an instance...")
            if len(dataset.intermediate_representation) > 0:
                print(json.dumps(dataset.intermediate_representation[0], indent=4))

            module = importlib.import_module(f"instructions.{each_instruction}")
            instruction_function = getattr(module, each_instruction)

            dataset.intermediate_representation = dataset.intermediate_representation.map(
                instruction_function,
                desc=f"Applying {each_instruction} transformation",
            )

            print("Printing instances post instruction transformation...")
            _print_instances(dataset.intermediate_representation)

            # Forward the --cot flag; previously it was parsed but never used,
            # so the CoT instruction could not be selected from the CLI.
            dataset.intermediate_representation = dataset.intermediate_representation.map(
                create_input_from_dataset,
                fn_kwargs={"COT": args.cot},
                desc="Creating Input...",
            )

            print("Printing instances post input transformation...")
            _print_instances(dataset.intermediate_representation)

            file_name = os.path.join(args.output_path, f"{each_dataset}_{each_instruction}.jsonl")

            dataset.intermediate_representation.to_json(file_name)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
