from util.constants import CLASSIFICATION, CATEGORY_MAPPING, DATASETS, TABLE_ORDERING
import json
import pandas as pd
from pathlib import Path

# Orderings read once, at import time, from util.constants.TABLE_ORDERING.
# NOTE(review): neither name is referenced in the visible part of this file;
# presumably they order models/categories in the generated tables — confirm.
model_table_ordering = TABLE_ORDERING().models
category_table_ordering = TABLE_ORDERING().categories


def levenshtein_no_replace(s1, s2):
    """Return the edit distance between ``s1`` and ``s2`` using only insertions and deletions.

    Substitutions are not allowed, so a mismatched character costs two
    operations (delete it, then insert the replacement). Equivalently this is
    ``len(s1) + len(s2) - 2 * LCS(s1, s2)``.

    Args:
        s1: Source string (or other sequence).
        s2: Target string (or other sequence).

    Returns:
        The minimum number of single-element insertions/deletions turning
        ``s1`` into ``s2``.
    """
    # Only the previous DP row is needed: prev_row[j] is the distance between
    # the prefix of s1 processed so far and s2[:j].
    prev_row = list(range(len(s2) + 1))  # "" -> s2[:j] takes j insertions
    for i, ch1 in enumerate(s1, start=1):
        cur_row = [i]  # s1[:i] -> "" takes i deletions
        for j, ch2 in enumerate(s2, start=1):
            if ch1 == ch2:
                # Matching characters are carried over for free.
                cur_row.append(prev_row[j - 1])
            else:
                # Either delete ch1 (row above) or insert ch2 (cell to the left).
                cur_row.append(min(prev_row[j], cur_row[j - 1]) + 1)
        prev_row = cur_row
    return prev_row[-1]


def _load_keyed_instances(root):
    """Index every JSON object found in ``root/**/*.jsonl`` by dataset, input and instruction id."""
    keyed = {}
    # Sorted so that, when a key occurs in several files, the surviving
    # instance does not depend on filesystem enumeration order.
    for path in sorted(Path(root).rglob("*.jsonl")):
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines instead of failing in json.loads
                instance = json.loads(line)
                key_val = instance["dataset"] + "_" + instance["dataset_input"].strip() + "_" + instance["instruction_id"]
                keyed[key_val] = instance
    return keyed


def new_reasoning_error_instances(
    if_follow_path="dataset_withreasoning_error_fix/instruction_follow",
    noif_follow_path="dataset_withreasoning_error_fix/no_instruction_follow",
):
    """Load the corrected-reasoning instances from two directory trees of JSONL files.

    Every ``*.jsonl`` file found recursively under each directory is read as
    UTF-8, one JSON object per non-blank line. Each object is indexed by
    ``"<dataset>_<dataset_input stripped>_<instruction_id>"``; when a key
    repeats, the object read last wins.

    Args:
        if_follow_path: Directory with the instruction-following instances.
        noif_follow_path: Directory with the no-instruction-following instances.

    Returns:
        ``(if_new_reasoning_dict, noif_new_reasoning_dict)``: one key -> instance
        dict per directory. A missing directory yields an empty dict.

    Raises:
        KeyError: if an object lacks ``dataset``, ``dataset_input`` or ``instruction_id``.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    if_new_reasoning_dict = _load_keyed_instances(if_follow_path)
    noif_new_reasoning_dict = _load_keyed_instances(noif_follow_path)
    return if_new_reasoning_dict, noif_new_reasoning_dict
    

def initialize_dataset_metrics(data):
    """Create a zeroed metrics record for every dataset that appears in ``data``.

    Args:
        data: Iterable of instance dicts, each with a ``"dataset"`` key.

    Returns:
        Dict mapping dataset name -> fresh counters: ``num_instances`` plus the
        ``exact_match`` and ``if_analysis`` sub-dicts, every value 0. Keys are
        inserted in sorted order so tables built from this dict are identical
        across runs (iterating a set of strings is not, because of hash
        randomization).
    """
    def new_counters():
        # Fresh dicts on every call so no two records share mutable state.
        return {
            "num_instances": 0,
            "exact_match": dict.fromkeys((
                "strict_after_response", "loose_v2", "loose_after_response", "loose",
                "no_response", "no_response_count",
                "print_correct_answer", "print_correct_answer_v2", "print_correct_answer_count",
                "print_correct_answer_label", "print_correct_answer_label_count",
            ), 0),
            "if_analysis": dict.fromkeys((
                "strict_reasoning", "strict_reasoning_count", "loose_reasoning",
                "strict_if", "strict_if_count", "loose_if",
                "strict_unclassified", "strict_unclassified_count", "loose_unclassified",
                "pca_reasoning", "pca_reasoning_count",
                "pca_if", "pca_if_count",
                "pca_unclassified", "pca_unclassified_count",
                "pca_label_reasoning", "pca_label_reasoning_count",
                "pca_label_if", "pca_label_if_count",
                "pca_label_unclassified", "pca_label_unclassified_count",
            ), 0),
        }

    unique_datasets = sorted({tst["dataset"] for tst in data})
    return {dataset: new_counters() for dataset in unique_datasets}

def initialize_instruction_metrics(data):
    """Create a zeroed metrics record for every instruction id that appears in ``data``.

    Args:
        data: Iterable of instance dicts, each with an ``"instruction_id"`` key.

    Returns:
        Dict mapping instruction id -> fresh counters: ``num_instances`` plus
        the ``exact_match`` and ``if_analysis`` sub-dicts, every value 0. Keys
        are inserted in sorted order so generated tables are reproducible
        across runs.
    """
    def new_counters():
        # Fresh dicts on every call so no two records share mutable state.
        return {
            "num_instances": 0,
            "exact_match": dict.fromkeys((
                "strict_after_response", "loose_v2", "loose_after_response", "loose",
                "no_response", "no_response_count",
                "print_correct_answer", "print_correct_answer_v2", "print_correct_answer_count",
                "print_correct_answer_label", "print_correct_answer_label_count",
            ), 0),
            "if_analysis": dict.fromkeys((
                "strict_reasoning", "strict_reasoning_count", "loose_reasoning",
                "strict_if", "strict_if_count", "loose_if",
                "strict_unclassified", "strict_unclassified_count", "loose_unclassified",
                "pca_reasoning", "pca_reasoning_count",
                "pca_if", "pca_if_count",
                "pca_unclassified", "pca_unclassified_count",
                "pca_label_reasoning", "pca_label_reasoning_count",
                "pca_label_if", "pca_label_if_count",
                "pca_label_unclassified", "pca_label_unclassified_count",
            ), 0),
        }

    unique_instructions = sorted({tst["instruction_id"] for tst in data})
    return {instruction: new_counters() for instruction in unique_instructions}

def initialize_classification_metrics(data):
    """Create a zeroed metrics record for every instruction classification in ``data``.

    Classifications are the ``CATEGORY_MAPPING`` categories of the instruction
    ids in ``data`` (sorted, for reproducible tables), followed by
    ``"print_correct_answer_label"`` and then ``"print_correct_answer"`` when
    those ids occur. The two print-correct-answer ids are not looked up in the
    mapping; each becomes its own classification.

    Args:
        data: Iterable of instance dicts, each with an ``"instruction_id"`` key.

    Returns:
        Dict mapping classification -> fresh counters (``num_instances`` plus
        ``exact_match`` / ``if_analysis`` sub-dicts, all 0).

    Raises:
        KeyError: if a non-print-correct-answer id is missing from the mapping.
    """
    def new_counters():
        # Fresh dicts on every call so no two records share mutable state.
        return {
            "num_instances": 0,
            "exact_match": dict.fromkeys((
                "strict_after_response", "loose_v2", "loose_after_response", "loose",
                "no_response", "no_response_count",
                "print_correct_answer", "print_correct_answer_v2", "print_correct_answer_count",
                "print_correct_answer_label", "print_correct_answer_label_count",
            ), 0),
            "if_analysis": dict.fromkeys((
                "strict_reasoning", "strict_reasoning_count", "loose_reasoning",
                "strict_if", "strict_if_count", "loose_if",
                "strict_unclassified", "strict_unclassified_count", "loose_unclassified",
                "pca_reasoning", "pca_reasoning_count",
                "pca_if", "pca_if_count",
                "pca_unclassified", "pca_unclassified_count",
                "pca_label_reasoning", "pca_label_reasoning_count",
                "pca_label_if", "pca_label_if_count",
                "pca_label_unclassified", "pca_label_unclassified_count",
            ), 0),
        }

    category_mapping = CATEGORY_MAPPING().mapping
    pca_ids = ("print_correct_answer_label", "print_correct_answer")

    # One pass over data; everything below works on the distinct ids.
    instruction_ids = {tst["instruction_id"] for tst in data}
    unique_classes = sorted({category_mapping[iid] for iid in instruction_ids if iid not in pca_ids})
    unique_classes += [iid for iid in pca_ids if iid in instruction_ids]

    return {classification: new_counters() for classification in unique_classes}
        

def initialize_dataset_instr_metrics(data):
    """Create zeroed metrics keyed by dataset, then by instruction id.

    Every dataset in ``data`` gets an entry holding one record per instruction
    id seen with that dataset, except that instructions whose
    ``CATEGORY_MAPPING`` category is ``"Operations on List"`` are left out for
    Piqa, BoolQ and Winogrande (so such a dataset's inner dict may be empty).
    Instruction ids missing from the mapping are never skipped. Both levels
    are inserted in sorted order for reproducible tables.

    Args:
        data: Iterable of instance dicts with ``"dataset"`` and ``"instruction_id"`` keys.

    Returns:
        ``{dataset: {instruction_id: counters}}`` where each counters dict holds
        ``num_instances`` plus ``exact_match`` / ``if_analysis`` sub-dicts, all 0.
    """
    def new_counters():
        # Fresh dicts on every call so no two records share mutable state.
        return {
            "num_instances": 0,
            "exact_match": dict.fromkeys((
                "strict_after_response", "loose_v2", "loose_after_response", "loose",
                "no_response", "no_response_count",
                "print_correct_answer", "print_correct_answer_v2", "print_correct_answer_count",
                "print_correct_answer_label", "print_correct_answer_label_count",
            ), 0),
            "if_analysis": dict.fromkeys((
                "strict_reasoning", "strict_reasoning_count", "loose_reasoning",
                "strict_if", "strict_if_count", "loose_if",
                "strict_unclassified", "strict_unclassified_count", "loose_unclassified",
                "pca_reasoning", "pca_reasoning_count",
                "pca_if", "pca_if_count",
                "pca_unclassified", "pca_unclassified_count",
                "pca_label_reasoning", "pca_label_reasoning_count",
                "pca_label_if", "pca_label_if_count",
                "pca_label_unclassified", "pca_label_unclassified_count",
            ), 0),
        }

    category_map = CATEGORY_MAPPING().mapping
    list_exempt_datasets = ("Piqa", "BoolQ", "Winogrande")

    # Group instruction ids by dataset in a single pass.
    instructions_by_dataset = {}
    for tst in data:
        instructions_by_dataset.setdefault(tst["dataset"], set()).add(tst["instruction_id"])

    metrics = {}
    for dataset in sorted(instructions_by_dataset):
        per_dataset = metrics[dataset] = {}
        for instruction in sorted(instructions_by_dataset[dataset]):
            # .get: ids absent from the mapping (e.g. the print-correct-answer
            # probes, if unmapped) must not raise here.
            if dataset in list_exempt_datasets and category_map.get(instruction) == "Operations on List":
                continue
            per_dataset[instruction] = new_counters()
    return metrics

def initialize_classification_instr_metrics(data):
    """Create zeroed metrics keyed by instruction classification, then by instruction id.

    Each instruction id in ``data`` is placed under its ``CATEGORY_MAPPING``
    category. Both levels are inserted in sorted order for reproducible tables.

    Args:
        data: Iterable of instance dicts, each with an ``"instruction_id"`` key.

    Returns:
        ``{classification: {instruction_id: counters}}`` where each counters dict
        holds ``num_instances`` plus ``exact_match`` / ``if_analysis`` sub-dicts, all 0.

    Raises:
        KeyError: if any instruction id is missing from the category mapping.
    """
    def new_counters():
        # Fresh dicts on every call so no two records share mutable state.
        return {
            "num_instances": 0,
            "exact_match": dict.fromkeys((
                "strict_after_response", "loose_v2", "loose_after_response", "loose",
                "no_response", "no_response_count",
                "print_correct_answer", "print_correct_answer_v2", "print_correct_answer_count",
                "print_correct_answer_label", "print_correct_answer_label_count",
            ), 0),
            "if_analysis": dict.fromkeys((
                "strict_reasoning", "strict_reasoning_count", "loose_reasoning",
                "strict_if", "strict_if_count", "loose_if",
                "strict_unclassified", "strict_unclassified_count", "loose_unclassified",
                "pca_reasoning", "pca_reasoning_count",
                "pca_if", "pca_if_count",
                "pca_unclassified", "pca_unclassified_count",
                "pca_label_reasoning", "pca_label_reasoning_count",
                "pca_label_if", "pca_label_if_count",
                "pca_label_unclassified", "pca_label_unclassified_count",
            ), 0),
        }

    category_mapping = CATEGORY_MAPPING().mapping

    # Group instruction ids by category in one pass instead of re-filtering
    # (and re-mapping) the whole dataset once per category.
    instructions_by_class = {}
    for tst in data:
        instruction = tst["instruction_id"]
        instructions_by_class.setdefault(category_mapping[instruction], set()).add(instruction)

    return {
        classification: {
            instruction: new_counters()
            for instruction in sorted(instructions_by_class[classification])
        }
        for classification in sorted(instructions_by_class)
    }

def initialize_classification_dataset_metrics(data):
    """Create zeroed metrics keyed by instruction classification, then by parent dataset.

    Classifications are the ``CATEGORY_MAPPING`` categories of the instruction
    ids in ``data`` (sorted), followed by ``"print_correct_answer_label"`` and
    then ``"print_correct_answer"`` when those ids occur. The two probes are
    grouped by their own id rather than through the mapping.

    For each classification, the raw ``"dataset"`` names of its instances are
    collapsed to parent datasets: the first entry of ``DATASETS().datasets``
    that is a substring of the raw name. Raw names with no such entry are
    dropped, and a classification left with no parent dataset gets no entry.
    ``"Operations on List"`` gets no Piqa, Winogrande or BoolQ entries, so its
    inner dict may be empty.

    Args:
        data: List of instance dicts with ``"instruction_id"`` and ``"dataset"``
            keys (iterated more than once).

    Returns:
        ``{classification: {parent_dataset: counters}}``, both levels sorted,
        where each counters dict holds ``num_instances`` plus ``exact_match`` /
        ``if_analysis`` sub-dicts, all 0.

    Raises:
        KeyError: if a non-probe instruction id is missing from the mapping.
    """
    def new_counters():
        # Fresh dicts on every call so no two records share mutable state.
        return {
            "num_instances": 0,
            "exact_match": dict.fromkeys((
                "strict_after_response", "loose_v2", "loose_after_response", "loose",
                "no_response", "no_response_count",
                "print_correct_answer", "print_correct_answer_v2", "print_correct_answer_count",
                "print_correct_answer_label", "print_correct_answer_label_count",
            ), 0),
            "if_analysis": dict.fromkeys((
                "strict_reasoning", "strict_reasoning_count", "loose_reasoning",
                "strict_if", "strict_if_count", "loose_if",
                "strict_unclassified", "strict_unclassified_count", "loose_unclassified",
                "pca_reasoning", "pca_reasoning_count",
                "pca_if", "pca_if_count",
                "pca_unclassified", "pca_unclassified_count",
                "pca_label_reasoning", "pca_label_reasoning_count",
                "pca_label_if", "pca_label_if_count",
                "pca_label_unclassified", "pca_label_unclassified_count",
            ), 0),
        }

    category_map = CATEGORY_MAPPING().mapping
    datasets = DATASETS().datasets
    pca_ids = ("print_correct_answer_label", "print_correct_answer")
    list_exempt_datasets = ("Piqa", "Winogrande", "BoolQ")

    def parent_dataset(raw_name):
        # First configured dataset whose name occurs inside raw_name, else None.
        return next((parent for parent in datasets if parent in raw_name), None)

    instruction_ids = {tst["instruction_id"] for tst in data}
    unique_classes = sorted({category_map[iid] for iid in instruction_ids if iid not in pca_ids})
    unique_classes += [iid for iid in pca_ids if iid in instruction_ids]

    metrics = {}
    for classification in unique_classes:
        if classification in pca_ids:
            raw_names = {d["dataset"] for d in data if d["instruction_id"] == classification}
        else:
            # .get: probe ids may be absent from the mapping; they must simply
            # not match rather than raise KeyError.
            raw_names = {d["dataset"] for d in data if category_map.get(d["instruction_id"]) == classification}
        parents = sorted({parent_dataset(name) for name in raw_names} - {None})
        if not parents:
            continue
        metrics[classification] = {}
        for dataset in parents:
            if classification == "Operations on List" and dataset in list_exempt_datasets:
                continue
            metrics[classification][dataset] = new_counters()
    return metrics

def initialize_classification_dataset_no_if_metrics(data):
    """Create the classification -> parent-dataset metrics for the no-instruction-following table.

    Used only for the paper's no-instruction-following table. The structure is
    computed for every classification, exactly as in
    ``initialize_classification_dataset_metrics``, to which this delegates; it
    stays a separate entry point so callers can keep a distinct name for that
    table.

    Args:
        data: List of instance dicts with ``"instruction_id"`` and ``"dataset"`` keys.

    Returns:
        ``{classification: {parent_dataset: counters}}`` with all counters at 0,
        a fresh object on every call.
    """
    # The previous body was a line-for-line copy of
    # initialize_classification_dataset_metrics; delegate so the two cannot drift.
    return initialize_classification_dataset_metrics(data)


def table_metrics(values: dict) -> list:
    """Flatten one filled-in metrics record into the row written to the result CSVs.

    Rates are hit counts divided by the matching instance count:
    - the original-prompt ("orig") rates use ``num_instances``;
    - the print-correct-answer ("pca") rates use ``print_correct_answer_count``;
    - the PCA-label rates use ``print_correct_answer_label_count``.
    A rate whose denominator is 0 is reported as 0 (previously this already
    held for the PCA rates; an empty ``num_instances`` raised ZeroDivisionError).

    Args:
        values: One record shaped like the ``initialize_*_metrics`` output
            (``num_instances`` plus ``exact_match`` / ``if_analysis`` dicts).

    Returns:
        A 26-element list, in order:
        orig_em, pca_em, orig_v2_em, pca_loose_v2, pca_label_em,
        num_instances, pca_count, pca_label_count,
        orig_reason, pca_reason, pca_label_reason,
        orig_if, pca_if, pca_label_if,
        orig_reason_count, pca_reason_count, pca_label_reason_count,
        orig_if_count, pca_if_count, pca_label_if_count,
        orig_unclass_count, pca_unclass_count, pca_label_unclass_count,
        orig_loose_reason, orig_loose_if, orig_loose_unclass.
    """
    exact_match = values["exact_match"]
    if_analysis = values["if_analysis"]

    def rate(section, key, count):
        # The numerator is only read when it is used, so an empty group does
        # not need every hit counter present.
        return section[key] / count if count > 0 else 0

    num_instances = values["num_instances"]
    pca_count = exact_match["print_correct_answer_count"]
    pca_label_count = exact_match["print_correct_answer_label_count"]

    return [
        rate(exact_match, "strict_after_response", num_instances),     # orig_em
        rate(exact_match, "print_correct_answer", pca_count),          # pca_em
        rate(exact_match, "loose_v2", num_instances),                  # orig_v2_em
        rate(exact_match, "print_correct_answer_v2", pca_count),       # pca_loose_v2
        rate(exact_match, "print_correct_answer_label", pca_label_count),  # pca_label_em
        num_instances,
        pca_count,
        pca_label_count,
        rate(if_analysis, "strict_reasoning", num_instances),          # orig_reason
        rate(if_analysis, "pca_reasoning", pca_count),                 # pca_reason
        rate(if_analysis, "pca_label_reasoning", pca_label_count),     # pca_label_reason
        rate(if_analysis, "strict_if", num_instances),                 # orig_if
        rate(if_analysis, "pca_if", pca_count),                        # pca_if
        rate(if_analysis, "pca_label_if", pca_label_count),            # pca_label_if
        if_analysis["strict_reasoning_count"],
        if_analysis["pca_reasoning_count"],
        if_analysis["pca_label_reasoning_count"],
        if_analysis["strict_if_count"],
        if_analysis["pca_if_count"],
        if_analysis["pca_label_if_count"],
        if_analysis["strict_unclassified_count"],
        if_analysis["pca_unclassified_count"],
        if_analysis["pca_label_unclassified_count"],
        rate(if_analysis, "loose_reasoning", num_instances),           # orig_loose_reason
        rate(if_analysis, "loose_if", num_instances),                  # orig_loose_if
        rate(if_analysis, "loose_unclassified", num_instances),        # orig_loose_unclass
    ]
     

def _flat_metric_rows(metrics):
    """Return one row ``[key, *table_metrics(values)]`` per entry of a flat ``{key: values}`` mapping."""
    return [[key] + table_metrics(values) for key, values in metrics.items()]


def _nested_metric_rows(metrics):
    """Return one row ``[outer, inner, *table_metrics(values)]`` per entry of ``{outer: {inner: values}}``."""
    return [
        [outer, inner] + table_metrics(values)
        for outer, inner_metrics in metrics.items()
        for inner, values in inner_metrics.items()
    ]


def _order_by_classification(df):
    """Return *df* sorted by its ``classification`` column following ``category_table_ordering``.

    Labels that are not listed in ``category_table_ordering`` are appended after the
    known ones (in order of first appearance) instead of being turned into NaN, so no
    row loses its classification label. A stable sort keeps the original row order
    within each classification.
    """
    categories = list(category_table_ordering)
    extras = [
        label
        for label in pd.unique(df["classification"])
        if pd.notna(label) and label not in categories
    ]
    ordered = df.assign(
        classification=pd.Categorical(
            df["classification"], ordered=True, categories=categories + extras
        )
    )
    # mergesort is stable: rows sharing a classification keep their insertion order.
    return ordered.sort_values("classification", kind="mergesort")


def write_metrics_to_csv(dataset_metrics, instruction_metrics, classification_metrics, data_instr_metrics, class_instr_metrics, class_data_metrics, class_data_noif_metrics, output_file_path):
    """Write aggregated metrics to an Excel workbook with one sheet per grouping.

    Despite the name, the output is an .xlsx workbook written via ``pd.ExcelWriter``
    with the ``xlsxwriter`` engine, not CSV.

    Each row consists of the grouping key column(s) followed by the values returned
    by ``table_metrics`` for that group, labelled with ``common_headers``. Sheets whose
    first column is ``classification`` are ordered by ``category_table_ordering``.

    Args:
        dataset_metrics: ``{dataset: values}`` -> sheet ``dataset``.
        instruction_metrics: ``{instruction: values}`` -> sheet ``instruction``.
        classification_metrics: ``{classification: values}`` -> sheet ``classification``.
        data_instr_metrics: ``{dataset: {instruction: values}}`` -> sheet ``dataset_instructions``.
        class_instr_metrics: ``{classification: {instruction: values}}`` -> sheet
            ``classification_instructions``.
        class_data_metrics: ``{classification: {dataset: values}}`` -> sheet
            ``classification_dataset``.
        class_data_noif_metrics: ``{classification: {dataset: values}}`` -> sheet
            ``noif_classification_dataset``.
        output_file_path: Destination workbook path; missing parent directories are created.
    """
    # Must stay aligned, in order, with the list returned by table_metrics.
    common_headers = ["orig_em", "pca_em", "orig_v2_em", "pca_loose_v2", "pca_label_em", "orig_count", "pca_count", "pca_label_count",
                      "orig_reason", "pca_reason", "pca_label_reason", "orig_if", "pca_if", "pca_label_if",
                      "orig_reason_count", "pca_reason_count", "pca_label_reason_count",
                      "orig_if_count", "pca_if_count", "pca_label_if_count",
                      "orig_unclass_count", "pca_unclass_count", "pca_label_unclass_count",
                      "orig_loose_reason", "orig_loose_if", "orig_loose_unclass"]

    # (sheet name, rows, key columns, order rows by classification?)
    # All rows are built before the output file is touched, so a failure in
    # table_metrics does not leave a partially written workbook behind.
    sheets = [
        ("dataset", _flat_metric_rows(dataset_metrics), ["dataset"], False),
        ("instruction", _flat_metric_rows(instruction_metrics), ["instruction"], False),
        ("classification", _flat_metric_rows(classification_metrics), ["classification"], True),
        ("dataset_instructions", _nested_metric_rows(data_instr_metrics), ["dataset", "instruction"], False),
        ("classification_instructions", _nested_metric_rows(class_instr_metrics), ["classification", "instruction"], True),
        ("classification_dataset", _nested_metric_rows(class_data_metrics), ["classification", "dataset"], True),
        ("noif_classification_dataset", _nested_metric_rows(class_data_noif_metrics), ["classification", "dataset"], True),
    ]

    Path(output_file_path).parent.mkdir(parents=True, exist_ok=True)

    # Context manager guarantees the workbook is closed even if a sheet fails to write.
    with pd.ExcelWriter(output_file_path, engine="xlsxwriter") as writer:
        for sheet_name, rows, key_columns, by_classification in sheets:
            df = pd.DataFrame(rows, columns=key_columns + common_headers)
            if by_classification:
                df = _order_by_classification(df)
            df.to_excel(writer, sheet_name=sheet_name, index=False)

