import json
# Root directory of the project; the paths below are built from it.
PATH = "."
GOLD_PATH = PATH + "/data"
PROMPTD_PATH = PATH + "/promptd_output"

# Per-task file locations.
#   "pred_path" is joined onto PROMPTD_PATH when the predictions are loaded.
#   "gold_path" is presumably relative to GOLD_PATH -- it is not read in the
#   visible part of this script, so confirm before relying on it.
tasks = {
    "maths_easy": {
        "gold_path": "MATHS/maths_easy_to_very_easy.json",
        "pred_path": "maths/MATHS_easy_to_very_easy_gpt-4_final.json",
    },
    "maths_hard": {
        "gold_path": "MATHS/maths_med_to_very_hard.json",
        "pred_path": "maths/MATHS_med_to_hard_gpt-4_final.json",
    },
    "gsm_8k": {
        "gold_path": "gsm8k/test_gsm8k.jsonl",
        "pred_path": "gsm8k/gsm8k_gpt-4_final.json",
    },
    "human_eval": {
        "gold_path": "human_eval/Human_Eval.jsonl",
        "pred_path": "human_eval/human_instructions_gpt-4_attempt_1.json",
    },
    "anachronisms": {
        "gold_path": "anachronisms/anachronisms.json",
        "pred_path": "anachronisms/anachronisms_gpt-4_final.json",
    },
    "analytical_entailment": {
        "gold_path": "analytical_entailment/analytical_entailment.json",
        "pred_path": "analytical_entailment/analytical_entailment_gpt-4_final.json",
    },
    "known_unknowns": {
        "gold_path": "known_unknowns/known_unknowns_hallucinations.json",
        "pred_path": "known_unknowns/hallucinations_gpt-4_final.json",
    },
    "date_understanding": {
        "gold_path": "date_understanding/date_understanding.json",
        "pred_path": "date_understanding/date_understanding_gpt-4_final.json",
    },
}

import random
# Seed the global RNG so the random.sample / random.shuffle calls in this
# script are reproducible across runs given identical input files.
random.seed(42)
def prepare_human_study_data(dataset, task_name, sample_size=25):
    """Sample examples and pair each candidate prompt with its PromptD output.

    Args:
        dataset: Either a list of example dicts, or a dict whose values are
            example dicts. Each example must contain a "question" field and
            the task-dependent output field described below.
        task_name: Task key. "maths_easy", "maths_hard" and "gsm8k" read the
            output from "final_response"; every other task reads it from
            "better_prompt".
        sample_size: Number of examples drawn without replacement.

    Returns:
        A list of ``{"candidate_prompt": ..., "promptd_output": ...}`` dicts,
        one per sampled example. The output is the text after the last
        "###Better Prompt###" marker (or the whole text if the marker is
        absent), stripped of surrounding whitespace.

    Raises:
        ValueError: Propagated from ``random.sample`` when ``sample_size``
            exceeds the number of available examples.
        KeyError: When a sampled example lacks a required field.
    """
    print(task_name, len(dataset))
    if isinstance(dataset, dict):
        # random.sample() requires a sequence; passing a dict key view raises
        # TypeError on Python 3.11+, so materialise the keys first.
        sampled_keys = random.sample(list(dataset), sample_size)
        examples = [dataset[key] for key in sampled_keys]
    else:
        examples = random.sample(dataset, sample_size)

    # NOTE(review): the task table in this file uses the key "gsm_8k", which
    # does not match "gsm8k" here, so that task currently takes the
    # "better_prompt" branch. Confirm which field the GSM8K prediction file
    # actually provides before changing this.
    if task_name in ("maths_easy", "maths_hard", "gsm8k"):
        # we are dealing with maths dataset
        output_field = "final_response"
    else:
        output_field = "better_prompt"

    return_data = []
    for example in examples:
        candidate_prompt = example["question"]
        raw_output = example[output_field]
        if not isinstance(raw_output, str):
            # Non-string outputs are treated as a sequence of responses;
            # only the first one is used.
            raw_output = raw_output[0]
        promptd_output = raw_output.split("###Better Prompt###")[-1].strip()
        return_data.append(
            {"candidate_prompt": candidate_prompt, "promptd_output": promptd_output}
        )
    return return_data

def read_datasets(dataset_address, task_name):
    """Load a dataset from a ``.json`` or ``.jsonl`` file.

    Args:
        dataset_address: Path to the file; the extension selects the parser.
        task_name: Accepted for call-site compatibility; not used here.

    Returns:
        For ``.jsonl`` files, a list with one parsed object per non-blank
        line. For ``.json`` files, the parsed document -- except that when
        the document is a dict with an "examples" key, the value of that
        key is returned instead.

    Raises:
        ValueError: If the path ends in neither ``.json`` nor ``.jsonl``.
    """
    if dataset_address.endswith(".json"):
        with open(dataset_address, "r", encoding="utf-8") as f:
            dataset = json.load(f)
    elif dataset_address.endswith(".jsonl"):
        with open(dataset_address, "r", encoding="utf-8") as f:
            # Skip blank lines (e.g. a trailing newline at end of file),
            # which json.loads would otherwise reject.
            dataset = [json.loads(line) for line in f if line.strip()]
    else:
        raise ValueError("Unrecognized dataset format")

    # Some files wrap their records in a top-level {"examples": [...]} object;
    # unwrap those and leave every other shape untouched.
    if isinstance(dataset, dict) and "examples" in dataset:
        dataset = dataset["examples"]
    return dataset

# Collect a sample from every task. The two maths splits get 13 and 12
# examples respectively; every other task gets 25.
final_datasaet = []
for task in tasks:
    if "_easy" in task:
        sample_size = 13
    elif "_hard" in task:
        sample_size = 12
    else:
        sample_size = 25
    dataset = read_datasets(f"{PROMPTD_PATH}/{tasks[task]['pred_path']}", task)
    study_data = prepare_human_study_data(dataset, task, sample_size)
    # To inspect each task's sample interactively, uncomment:
    # print(json.dumps(study_data, indent=4))
    # _ = input("press enter to continue...")
    final_datasaet.extend(study_data)

# Mix the tasks together so the output is not grouped by task.
random.shuffle(final_datasaet)

# Convert to CSV.
import csv
# newline="" is required by the csv module so that it controls line endings
# itself; without it, rows can gain extra blank lines on Windows and newlines
# embedded in quoted fields may be written incorrectly. An explicit encoding
# keeps non-ASCII prompt text from failing on non-UTF-8 default locales.
with open(f"{PATH}/human_study_data.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=["candidate_prompt", "promptd_output"])
    writer.writeheader()
    writer.writerows(final_datasaet)
