import argparse
import pdb
import json
from util import metric_utils, constants
import copy
import string
from pathlib import Path
from tqdm import tqdm
from Levenshtein import distance, editops


# Module-level log: compute_reasoning_if_errors appends one
# "<flag> , <dataset> , <instruction_id>" string for every instance whose key
# is absent from the regenerated reasoning-error dictionary.
missing_stuff = []
# Largest weighted edit distance still accepted by compute_levenshtein_based_match.
LEVENSHTEIN_DISTANCE_THRESHOLD = 2
# Largest weighted edit distance still accepted by
# compute_reasoning_analysis_levenshtein_based_match (error-set matching).
LEVENSHTEIN_ERROR_SETS_DISTANCE_THRESHOLD = 4
# Error-set candidates must be strictly longer than this many characters before
# the fuzzy (Levenshtein) comparison is attempted for them.
REASONING_OR_IF_CANDIDATE_LENGTH_THRESHOLD = 6

# Passed as `weights=` to Levenshtein.distance.
# NOTE(review): presumably (insertion, deletion, substitution) costs - confirm
# against the Levenshtein package documentation.
LEVENSHTEIN_WEIGHTS = (1,1,2)

def clean_llm_output(prediction: str):
    """Return ``prediction`` without a single trailing period.

    Only one ``.`` is removed, so ``"a.."`` becomes ``"a."``; a string that
    does not end with a period is returned unchanged.
    """
    return prediction[:-1] if prediction.endswith(".") else prediction

def initial_parse_llm_output(test_instance: dict):
    """Split a model output into a strict and a loose prediction.

    Args:
        test_instance: record whose ``"output"`` entry holds the raw model text.

    Returns:
        A ``(strict, loose, response_present)`` tuple. ``loose`` is always the
        whole stripped output with one trailing period removed. ``strict`` is
        the text after the last ``"Response:"`` marker (stripped, trailing
        period removed) when the marker is present, otherwise it equals
        ``loose``. ``response_present`` tells whether the marker was found.
    """
    raw_output = test_instance["output"]
    whole_output = clean_llm_output(raw_output.strip())

    if "Response:" in raw_output:
        text_after_marker = raw_output.split("Response:")[-1]
        return clean_llm_output(text_after_marker.strip()), whole_output, True

    # No marker: the strict and loose views are the same text.
    return whole_output, whole_output, False

def loose_parse_llm_output(test_instance: dict):
    """Extract the prediction used by the loose (Levenshtein based) exact match.

    Args:
        test_instance: record whose ``"output"`` entry holds the raw model text.

    Returns:
        A ``(prediction, if_analysis_prediction, response_present)`` tuple.

        * With a ``"Response:"`` marker: ``prediction`` is the text after the
          last marker, ``if_analysis_prediction`` is the whole output; both are
          stripped and lose one trailing period.
        * Without the marker: only the last line of the output is considered,
          and anything up to the last ``"answer is:"`` / ``"answer is :"`` on
          that line is discarded. Both returned predictions are that remainder.
    """
    raw_output = test_instance["output"]

    if "Response:" in raw_output:
        response_prediction = clean_llm_output(raw_output.split("Response:")[-1].strip())
        if_analysis_loose_prediction = clean_llm_output(raw_output.strip())
        return response_prediction, if_analysis_loose_prediction, True

    # Strip BEFORE taking the last line: an output that ends with a newline
    # would otherwise produce an empty "last line" and thus an empty prediction.
    last_line = raw_output.strip().rsplit("\n", 1)[-1]

    # First marker spelling found on the line wins (same priority as before).
    for marker in ("answer is:", "answer is :"):
        if marker in last_line:
            last_line = last_line.rsplit(marker, 1)[-1]
            break

    response_prediction = clean_llm_output(last_line.strip())
    return response_prediction, response_prediction, False

def loose_error_analysis_parse_llm_output(test_instance: dict):
    """Extract the prediction used when classifying errors.

    Args:
        test_instance: record whose ``"output"`` entry holds the raw model text.

    Returns:
        A ``(prediction, if_analysis_prediction, response_present)`` tuple.

        With a ``"Response:"`` marker the first element is the text after the
        last marker and the second is the whole output (both stripped, one
        trailing period removed). Without it, the stripped output is cut down
        by the first rule that applies and both elements are that result:

        1. text after the last ``"answer is"``, with every ``:`` removed;
        2. text after the last ``:``;
        3. text after the last newline;
        4. the stripped output itself.
    """
    raw_output = test_instance["output"]

    if "Response:" in raw_output:
        after_marker = clean_llm_output(raw_output.split("Response:")[-1].strip())
        whole_output = clean_llm_output(raw_output.strip())
        return after_marker, whole_output, True

    trimmed = raw_output.strip()
    # The membership tests look at the raw text, the cuts act on the trimmed text.
    if "answer is" in raw_output:
        tail = trimmed.rsplit("answer is", 1)[-1].replace(":", "").strip()
    elif ":" in raw_output:
        tail = trimmed.rsplit(":", 1)[-1]
    elif "\n" in raw_output:
        tail = trimmed.rsplit("\n", 1)[-1]
    else:
        tail = trimmed

    prediction = clean_llm_output(tail.strip())
    return prediction, prediction, False

def compare_removing_whitespace(ground_truth, prediction):
    """Return True when both strings are equal once every ``string.whitespace`` character is deleted."""
    drop_whitespace = str.maketrans("", "", string.whitespace)
    squeezed_truth = ground_truth.translate(drop_whitespace)
    squeezed_prediction = prediction.translate(drop_whitespace)
    return squeezed_truth == squeezed_prediction

def compute_exact_match(test_instance: dict, print_correct_answer_instance: dict, print_correct_answer_label_instance: dict):
    """Compute the exact-match flags for one instance and its reference instances.

    Every string that takes part in a comparison is normalised the same way:
    stringified, stripped, and with ``[``, ``]`` and ``'`` deleted. Previously
    the predictions of the two reference instances were compared raw against a
    normalised ground truth, so a bracketed/quoted prediction could never match.

    Args:
        test_instance: record with ``"output"`` (model text) and
            ``"instruction_output"`` (sequence whose last item is taken as the
            ground truth).
        print_correct_answer_instance: record of the same shape for the
            ``print_correct_answer`` instruction, or a falsy value to skip it.
        print_correct_answer_label_instance: same, for the
            ``print_correct_answer_label`` instruction.

    Returns:
        A 7-tuple of 0/1 integers:
        ``(strict_exact_match_after_response, loose_exact_match_after_response,
        loose_exact_match, exact_match_no_response, no_response_count,
        pca_exact_match, pca_label_exact_match)``.
    """
    markup_table = str.maketrans("", "", "[]'")

    def _plain(value):
        # Stringify, trim, then drop list brackets and single quotes.
        return str(value).strip().translate(markup_table)

    def _agrees(reference, candidate):
        # Equal as-is, or equal once all whitespace is ignored.
        return candidate == reference or compare_removing_whitespace(reference, candidate)

    def _reference_match(instance):
        # Exact match of a reference instance's own output against its own ground truth.
        if not instance:
            return 0
        reference = _plain(instance["instruction_output"][-1])
        candidate = _plain(initial_parse_llm_output(instance)[0])
        return int(_agrees(reference, candidate))

    # TODO: Making assumption on ground truth output location
    outputs = test_instance["instruction_output"]
    ground_truth = _plain(outputs[-1])

    # Everything after the first item counts as a candidate; when more than one
    # such item exists the last one (the ground truth) is left out.
    instruction_candidates = outputs[1:]
    if len(instruction_candidates) > 1:
        instruction_candidates = instruction_candidates[:-1]

    parsed_strict, loose_prediction, response_present_flag = initial_parse_llm_output(test_instance)
    strict_prediction = _plain(parsed_strict)

    strict_exact_match_after_response = int(_agrees(ground_truth, strict_prediction))
    loose_exact_match_after_response = int(strict_prediction in instruction_candidates)
    loose_exact_match = int(ground_truth in loose_prediction)
    no_response_count = int(not response_present_flag)
    # Deliberately the plain equality only, as in the original implementation.
    exact_match_no_response = int(not response_present_flag and strict_prediction == ground_truth)

    pca_exact_match = _reference_match(print_correct_answer_instance)
    pca_label_exact_match = _reference_match(print_correct_answer_label_instance)

    return strict_exact_match_after_response, loose_exact_match_after_response, loose_exact_match, exact_match_no_response, no_response_count, pca_exact_match, pca_label_exact_match

def compute_levenshtein_based_match(pred, gt):
    """Return 1 when ``pred`` is within the loose edit-distance threshold of ``gt``, else 0.

    Two weighted distances are computed with ``LEVENSHTEIN_WEIGHTS``: one on the
    strings as given and one after deleting every space character. The match
    succeeds if either is at most ``LEVENSHTEIN_DISTANCE_THRESHOLD``.
    """
    raw_distance = distance(pred, gt, weights=LEVENSHTEIN_WEIGHTS)
    spaceless_distance = distance(
        pred.replace(" ", ""), gt.replace(" ", ""), weights=LEVENSHTEIN_WEIGHTS
    )
    # "either one is small enough" is the same as "the smaller one is small enough".
    return 1 if min(raw_distance, spaceless_distance) <= LEVENSHTEIN_DISTANCE_THRESHOLD else 0


def compute_reasoning_analysis_levenshtein_based_match(pred, gt):
    """Return 1 when ``pred`` is within the error-set edit-distance threshold of ``gt``, else 0.

    Same procedure as the loose exact match (distance on the given strings and
    on the strings with spaces deleted, both using ``LEVENSHTEIN_WEIGHTS``), but
    compared against ``LEVENSHTEIN_ERROR_SETS_DISTANCE_THRESHOLD``.
    """
    string_pairs = (
        (pred, gt),
        (pred.replace(" ", ""), gt.replace(" ", "")),
    )
    # Both distances are computed up front, exactly as before.
    scores = [distance(left, right, weights=LEVENSHTEIN_WEIGHTS) for left, right in string_pairs]
    return int(any(score <= LEVENSHTEIN_ERROR_SETS_DISTANCE_THRESHOLD for score in scores))



def compute_loose_exact_match(test_instance: dict, print_correct_answer_instance: dict, print_correct_answer_label_instance: dict):
    """Compute the loose ("v2") exact match for an instance and its print_correct_answer reference.

    Ground truths and predictions are normalised identically (stringified,
    stripped, ``[``, ``]`` and ``'`` deleted). The reference prediction used to
    be compared un-normalised, so the same output could score differently
    depending on whether it was evaluated as the instance or as the reference.

    Args:
        test_instance: record with ``"output"``, ``"instruction_output"`` (last
            item is the ground truth) and ``"instruction_id"``.
        print_correct_answer_instance: record of the same shape for the
            ``print_correct_answer`` instruction, or a falsy value to skip it.
        print_correct_answer_label_instance: accepted for signature
            compatibility; it does not influence either returned value.

    Returns:
        ``(loose_exact_match, pca_loose_exact_match)`` as 0/1 integers.
    """
    markup_table = str.maketrans("", "", "[]'")

    def _plain(value):
        # Stringify, trim, then drop list brackets and single quotes.
        return str(value).strip().translate(markup_table)

    ground_truth = _plain(test_instance["instruction_output"][-1])
    loose_prediction = _plain(loose_parse_llm_output(test_instance)[0])

    if test_instance["instruction_id"] != "print_correct_answer_label":
        loose_exact_match = compute_levenshtein_based_match(loose_prediction, ground_truth)
    else:
        # Labels get no fuzzy tolerance: only exact or whitespace-insensitive equality counts.
        loose_exact_match = int(
            loose_prediction == ground_truth
            or compare_removing_whitespace(ground_truth, loose_prediction)
        )

    pca_loose_exact_match = 0
    if print_correct_answer_instance:
        pca_gt = _plain(print_correct_answer_instance["instruction_output"][-1])
        loose_pca_pred = _plain(loose_parse_llm_output(print_correct_answer_instance)[0])
        pca_loose_exact_match = compute_levenshtein_based_match(loose_pca_pred, pca_gt)

    return loose_exact_match, pca_loose_exact_match
    

def add_missing_labels_in_if_error(test_instance):
    """Build two extra instruction-following error candidates for an instance.

    Both candidates combine the stripped ``"ground_truth_answer_label"`` with
    the stripped last item of ``"instruction_output"``:

    * ``label + answer``
    * ``label + "." + answer``

    Returns:
        A two-element list in the order shown above.
    """
    label = str(test_instance["ground_truth_answer_label"]).strip()
    answer = str(test_instance["instruction_output"][-1]).strip()
    return [separator.join((label, answer)) for separator in ("", ".")]


def _strip_list_markup(value):
    """Stringify ``value``, trim it, and delete every ``[``, ``]`` and ``'`` character."""
    return str(value).strip().translate(str.maketrans("", "", "[]'"))


def _normalized_candidates(candidates):
    """Return a NEW list of normalised candidates; the input sequence is not modified."""
    return [_strip_list_markup(candidate) for candidate in candidates]


def _strict_error_match(prediction, candidates):
    """Return 1 if ``prediction`` matches any candidate exactly, whitespace-insensitively or fuzzily, else 0."""
    for candidate in candidates:
        # Short candidates are too easy to hit by accident, so fuzzy matching
        # is reserved for candidates longer than the length threshold.
        if (
            len(candidate) > REASONING_OR_IF_CANDIDATE_LENGTH_THRESHOLD
            and compute_reasoning_analysis_levenshtein_based_match(prediction, candidate) == 1
        ):
            return 1
        if prediction == candidate or compare_removing_whitespace(candidate, prediction):
            return 1
    return 0


def compute_reasoning_if_errors(test_instance: dict, print_correct_answer_instance: dict, print_correct_answer_label_instance: dict, strict_em_after_response: int, loose_v2_em: int, new_reasoning_error_dict: dict, if_or_noif_flag: str):
    """Classify a wrong prediction as a reasoning error, an instruction-following error, or unclassified.

    Classification only happens when ``loose_v2_em == 0``; otherwise every
    returned flag is 0.

    Fixes relative to the previous implementation:

    * The two extra candidates from ``add_missing_labels_in_if_error`` are now
      added as two separate candidates. They used to be appended as one nested
      list, which normalised into the single string ``"<c1>, <c2>"``.
    * Candidate lists are normalised into fresh lists. The lists stored in
      ``new_reasoning_error_dict`` and in the instances are no longer modified,
      so they neither grow nor get re-normalised on repeated calls.
    * The print_correct_answer prediction is normalised like the primary one.

    Args:
        test_instance: record providing ``"output"``, ``"dataset"``,
            ``"dataset_input"``, ``"instruction_id"``, ``"instruction_output"``
            and ``"ground_truth_answer_label"``; its ``"reasoning_error_set"``
            and ``"instruction_following_errors_set"`` are the fallback sets.
        print_correct_answer_instance: reference record (with ``"output"`` and
            the two error sets), or a falsy value to skip it.
        print_correct_answer_label_instance: accepted for signature
            compatibility; the three ``pca_label_*`` results are always 0.
        strict_em_after_response: accepted for signature compatibility; unused.
        loose_v2_em: loose exact-match flag of ``test_instance``.
        new_reasoning_error_dict: mapping from
            ``"<dataset>_<stripped dataset_input>_<instruction_id>"`` to a dict
            holding the two error sets.
        if_or_noif_flag: label recorded in ``missing_stuff`` when the key is absent.

    Returns:
        A 12-tuple of 0/1 integers: ``(strict_reasoning_error,
        loose_reasoning_error, strict_if_error, loose_if_error,
        strict_unclassified_error, loose_unclassified_error, pca_reasoning,
        pca_if, pca_unclass, pca_label_reasoning, pca_label_if,
        pca_label_unclass)``.
    """
    strict_reasoning_error = loose_reasoning_error = 0
    strict_if_error = loose_if_error = 0
    strict_unclassified_error = loose_unclassified_error = 0
    pca_reasoning = pca_if = pca_unclass = 0
    # No classification is performed for the label reference instance.
    pca_label_reasoning = pca_label_if = pca_label_unclass = 0

    # This is now the new loose version
    strict_prediction, loose_prediction, _ = loose_error_analysis_parse_llm_output(test_instance)
    strict_prediction = _strip_list_markup(strict_prediction)

    key_val = test_instance["dataset"] + "_" + test_instance["dataset_input"].strip() + "_" + test_instance["instruction_id"]
    if key_val in new_reasoning_error_dict:
        error_sets = new_reasoning_error_dict[key_val]
    else:
        # Remember which instances had to fall back to their own error sets.
        missing_stuff.append(if_or_noif_flag + " , " + test_instance["dataset"] + " , " + test_instance["instruction_id"])
        error_sets = test_instance

    reasoning_set = _normalized_candidates(error_sets["reasoning_error_set"])
    if_set = _normalized_candidates(
        [
            *error_sets["instruction_following_errors_set"],
            *add_missing_labels_in_if_error(test_instance),
        ]
    )

    if loose_v2_em == 0:
        # Strict view: compare the extracted answer against each candidate.
        strict_reasoning_error = _strict_error_match(strict_prediction, reasoning_set)
        strict_if_error = _strict_error_match(strict_prediction, if_set)
        if strict_if_error == 0 and strict_reasoning_error == 0:
            strict_unclassified_error = 1

        # Loose view: a candidate only has to occur somewhere in the prediction text.
        loose_reasoning_error = int(any(candidate in loose_prediction for candidate in reasoning_set))
        loose_if_error = int(any(candidate in loose_prediction for candidate in if_set))
        if loose_reasoning_error == 0 and loose_if_error == 0:
            loose_unclassified_error = 1

        if print_correct_answer_instance:
            strict_pca_pred, _, _ = loose_error_analysis_parse_llm_output(print_correct_answer_instance)
            strict_pca_pred = _strip_list_markup(strict_pca_pred)
            pca_reasoning = _strict_error_match(
                strict_pca_pred,
                _normalized_candidates(print_correct_answer_instance["reasoning_error_set"]),
            )
            pca_if = _strict_error_match(
                strict_pca_pred,
                _normalized_candidates(print_correct_answer_instance["instruction_following_errors_set"]),
            )
            if pca_reasoning == 0 and pca_if == 0:
                pca_unclass = 1

    return strict_reasoning_error, loose_reasoning_error, strict_if_error, loose_if_error, strict_unclassified_error, loose_unclassified_error, pca_reasoning, pca_if, pca_unclass, pca_label_reasoning, pca_label_if, pca_label_unclass


def add_count_to_metrics(metrics_dict, strict_after_response_em, loose_v2_em, loose_after_response_em, loose_response_em, no_response_em, no_response_count_em, pca_exact_match, pca_loose_v2_em, pca_label_exact_match, strict_reasoning_error, loose_reasoning_error, strict_if_error, loose_if_error, strict_unclassified_error, loose_unclassified_error, pca_reasoning, pca_if, pca_unclass, pca_label_reasoning, pca_label_if, pca_label_unclass, pca_count, pca_label_count):
    """Add one instance's flags to the running totals in ``metrics_dict``.

    ``metrics_dict`` is updated in place and also returned. It must hold a
    ``"num_instances"`` counter plus ``"exact_match"`` and ``"if_analysis"``
    sub-dicts containing the keys written below. For the flags handled in the
    final loop (and ``strict_unclassified``) the matching ``*_count`` entry is
    only touched when the flag equals exactly 1.
    """
    metrics_dict["num_instances"] += 1

    exact_match_totals = metrics_dict["exact_match"]
    exact_match_increments = (
        ("strict_after_response", strict_after_response_em),
        ("loose_v2", loose_v2_em),
        ("loose_after_response", loose_after_response_em),
        ("loose", loose_response_em),
        ("no_response", no_response_em),
        ("no_response_count", no_response_count_em),
        ("print_correct_answer", pca_exact_match),
        ("print_correct_answer_v2", pca_loose_v2_em),
        ("print_correct_answer_count", pca_count),
        ("print_correct_answer_label", pca_label_exact_match),
        ("print_correct_answer_label_count", pca_label_count),
    )
    for total_key, increment in exact_match_increments:
        exact_match_totals[total_key] += increment

    analysis_totals = metrics_dict["if_analysis"]
    # These counters always receive the raw flag value.
    unconditional_increments = (
        ("strict_reasoning", strict_reasoning_error),
        ("strict_reasoning_count", strict_reasoning_error),
        ("loose_reasoning", loose_reasoning_error),
        ("strict_if", strict_if_error),
        ("strict_if_count", strict_if_error),
        ("loose_if", loose_if_error),
    )
    for total_key, increment in unconditional_increments:
        analysis_totals[total_key] += increment

    analysis_totals["strict_unclassified"] += strict_unclassified_error
    if strict_unclassified_error == 1:
        analysis_totals["strict_unclassified_count"] += 1
    analysis_totals["loose_unclassified"] += loose_unclassified_error

    # (total key, count key, flag): the count key is bumped only for a flag of 1.
    flag_with_count = (
        ("pca_reasoning", "pca_reasoning_count", pca_reasoning),
        ("pca_if", "pca_if_count", pca_if),
        ("pca_unclassified", "pca_unclassified_count", pca_unclass),
        ("pca_label_reasoning", "pca_label_reasoning_count", pca_label_reasoning),
        ("pca_label_if", "pca_label_if_count", pca_label_if),
        ("pca_label_unclassified", "pca_label_unclassified_count", pca_label_unclass),
    )
    for total_key, count_key, flag in flag_with_count:
        analysis_totals[total_key] += flag
        if flag == 1:
            analysis_totals[count_key] += 1

    return metrics_dict

# Instruction ids whose instances double as per-input reference answers.
_PCA_INSTRUCTION_IDS = ("print_correct_answer", "print_correct_answer_label")


def _read_jsonl(path):
    """Parse every line of the file at ``path`` as one JSON document and return them as a list."""
    with open(path) as handle:
        return [json.loads(line) for line in handle]


def _select_class_key(metrics, current_class, current_instruction):
    """Pick the key of ``metrics`` an instance is counted under, or None if it has no bucket.

    The two print_correct_answer instructions are always counted under their
    own id; other instructions use their class when ``metrics`` has it and
    fall back to the instruction id otherwise.
    """
    if current_class not in metrics and current_instruction not in metrics:
        return None
    if current_instruction in _PCA_INSTRUCTION_IDS:
        return current_instruction
    if current_class in metrics:
        return current_class
    return current_instruction


def main():
    """Score model-output jsonl files and write one results workbook per file.

    Command line:
        --file FILE             score a single jsonl file (requires --output).
        --folder FOLDER         score every ``*.jsonl`` found recursively under
                                FOLDER; takes precedence over --file.
        --output PATH           results path used in --file mode.
        --output_folder FOLDER  optional destination folder in --folder mode;
                                defaults to the folder of each input file.

    For every input file the print_correct_answer / print_correct_answer_label
    instances are indexed by ``dataset_input`` (entries from a matching file
    under ``generations_to_run`` take priority over the input file's own), each
    instance is scored, and the totals are handed to
    ``metric_utils.write_metrics_to_csv``.

    Raises:
        ValueError: if an instance's dataset matches none of the known datasets.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--file", required=False, help="Jsonl file path with model output"
    )
    parser.add_argument(
        "--folder", required=False, help="Parent folder container model subfolders, each with jsonl result files"
    )
    parser.add_argument(
        "--output", required=False, help="output file path and name for xlsx results file"
    )
    parser.add_argument(
        "--output_folder", required=False, help="output folder path"
    )
    args = parser.parse_args()

    # Validate up front with a proper usage error: an `assert` disappears under
    # `python -O`, and a missing --file used to turn into the path "None".
    if not args.folder:
        if not args.file:
            parser.error("one of --file or --folder is required")
        if args.output is None:
            parser.error("--output is required when --file is used")

    new_if_reasoning_dict, new_noif_reasoning_dict = metric_utils.new_reasoning_error_instances()
    print("-- generated new reasoning error dicts")

    benchmark_folder_path = "generations_to_run"
    benchmark_folder_files = [str(pth) for pth in Path(benchmark_folder_path).rglob("*.jsonl")]

    if args.folder:
        files = [str(pth) for pth in Path(args.folder).rglob("*.jsonl")]
    else:
        files = [str(args.file)]

    for file_path in tqdm(files, desc="For each file.."):
        # The model name is the parent directory name without its
        # "_no_instruction..." / "_instruction..." suffix.
        model_name = file_path.rsplit("/", 1)[0].split("/")[-1]
        if "no_instruction" in model_name:
            model_name = model_name.split("_no_instruction")[0]
        else:
            model_name = model_name.split("_instruction")[0]

        # Constructing the print_correct_answer and print_correct_answer_label dict
        corrected_pca_path = next(
            (candidate for candidate in benchmark_folder_files if model_name in candidate), ""
        )
        corrected_pca_instances = _read_jsonl(corrected_pca_path) if corrected_pca_path != "" else []

        benchmark_dict_print_correct_answer = {}
        benchmark_dict_print_correct_answer_label = {}
        lookup_by_instruction = {
            "print_correct_answer": benchmark_dict_print_correct_answer,
            "print_correct_answer_label": benchmark_dict_print_correct_answer_label,
        }
        # Corrected instances are inserted first and therefore win ...
        for instance in corrected_pca_instances:
            lookup = lookup_by_instruction.get(instance["instruction_id"])
            if lookup is not None:
                lookup[instance["dataset_input"]] = instance
        # ... and the file's own instances only fill the remaining gaps. The
        # file is parsed separately here so the lookup tables hold objects that
        # are independent of the `data` list scored below.
        for instance in _read_jsonl(file_path):
            lookup = lookup_by_instruction.get(instance["instruction_id"])
            if lookup is not None:
                lookup.setdefault(instance["dataset_input"], instance)

        ##############################
        print(f"processing file --- {file_path}")
        data = _read_jsonl(file_path)

        if args.folder:
            current_path_split = file_path.rsplit("/", 1)
            filename = current_path_split[0].split("/", 1)[-1] + "_results.xlsx"
            if args.output_folder:
                output_file = str(args.output_folder) + "/" + filename
            else:
                output_file = current_path_split[0] + "/" + filename
        else:
            output_file = args.output

        if "no_instruction" in file_path:
            new_reasoning_error_dict = new_noif_reasoning_dict
            if_or_noif_flag = "no_instr_follow"
        else:
            new_reasoning_error_dict = new_if_reasoning_dict
            if_or_noif_flag = "instr_follow"

        dataset_metrics = metric_utils.initialize_dataset_metrics(data)
        instruction_metrics = metric_utils.initialize_instruction_metrics(data)
        classification_metrics = metric_utils.initialize_classification_metrics(data)
        data_instr_metrics = metric_utils.initialize_dataset_instr_metrics(data)
        class_instr_metrics = metric_utils.initialize_classification_instr_metrics(data)
        class_data_metrics = metric_utils.initialize_classification_dataset_metrics(data)
        class_data_noif_metrics = metric_utils.initialize_classification_dataset_no_if_metrics(data)

        category_mapping = constants.CATEGORY_MAPPING()
        datasets = constants.DATASETS().datasets

        for test_instance in data:
            current_dataset = test_instance["dataset"]
            current_instruction = test_instance["instruction_id"]
            current_class = category_mapping.mapping[current_instruction]

            if current_dataset in ["BoolQ", "Winogrande", "Piqa"] and current_class == "Operations on List":
                continue

            current_dataset_input = test_instance["dataset_input"]
            print_correct_answer_instance = benchmark_dict_print_correct_answer.get(current_dataset_input, {})
            print_correct_answer_label_instance = benchmark_dict_print_correct_answer_label.get(current_dataset_input, {})
            pca_count = 0 if print_correct_answer_instance == {} else 1
            pca_label_count = 0 if print_correct_answer_label_instance == {} else 1

            # First known dataset name contained in this instance's dataset name.
            current_parent_dataset = next((dset for dset in datasets if dset in current_dataset), "")
            if current_parent_dataset == "":
                raise ValueError(f"no known parent dataset for dataset {current_dataset!r}")

            (
                strict_exact_match_after_response,
                loose_exact_match_after_response,
                loose_exact_match,
                exact_match_no_response,
                no_response_count,
                pca_exact_match,
                pca_label_exact_match,
            ) = compute_exact_match(test_instance, print_correct_answer_instance, print_correct_answer_label_instance)
            loose_v2_em, pca_loose_v2_em = compute_loose_exact_match(test_instance, print_correct_answer_instance, print_correct_answer_label_instance)
            # 12 flags, already in the order add_count_to_metrics expects them.
            error_flags = compute_reasoning_if_errors(test_instance, print_correct_answer_instance, print_correct_answer_label_instance, strict_exact_match_after_response, loose_v2_em, new_reasoning_error_dict, if_or_noif_flag)

            # Positional arguments for add_count_to_metrics (everything after the bucket).
            counts = (
                strict_exact_match_after_response,
                loose_v2_em,
                loose_exact_match_after_response,
                loose_exact_match,
                exact_match_no_response,
                no_response_count,
                pca_exact_match,
                pca_loose_v2_em,
                pca_label_exact_match,
                *error_flags,
                pca_count,
                pca_label_count,
            )

            dataset_metrics[current_dataset] = add_count_to_metrics(dataset_metrics[current_dataset], *counts)
            instruction_metrics[current_instruction] = add_count_to_metrics(instruction_metrics[current_instruction], *counts)

            class_variable = _select_class_key(classification_metrics, current_class, current_instruction)
            if class_variable is not None:
                classification_metrics[class_variable] = add_count_to_metrics(classification_metrics[class_variable], *counts)

            if current_dataset in data_instr_metrics and current_instruction in data_instr_metrics[current_dataset]:
                data_instr_metrics[current_dataset][current_instruction] = add_count_to_metrics(data_instr_metrics[current_dataset][current_instruction], *counts)

            if current_class in class_instr_metrics and current_instruction in class_instr_metrics[current_class]:
                class_instr_metrics[current_class][current_instruction] = add_count_to_metrics(class_instr_metrics[current_class][current_instruction], *counts)

            for per_dataset_metrics in (class_data_metrics, class_data_noif_metrics):
                class_variable = _select_class_key(per_dataset_metrics, current_class, current_instruction)
                if class_variable is not None and current_parent_dataset in per_dataset_metrics[class_variable]:
                    per_dataset_metrics[class_variable][current_parent_dataset] = add_count_to_metrics(per_dataset_metrics[class_variable][current_parent_dataset], *counts)

        metric_utils.write_metrics_to_csv(dataset_metrics, instruction_metrics, classification_metrics, data_instr_metrics, class_instr_metrics, class_data_metrics, class_data_noif_metrics, output_file)


# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()