import json
from sklearn.metrics import accuracy_score, f1_score
import argparse
import os
import sys


def get_dataset_file_path(image_path):
    """Return the path of the ``diagram_info.json`` belonging to *image_path*.

    The dataset root is taken to be everything before the LAST occurrence of
    the substring ``'images'`` in *image_path*. If that substring does not
    occur, the whole *image_path* is used as the root.
    """
    dataset_root = image_path.rsplit('images', 1)[0]
    return os.path.join(dataset_root, "diagram_info.json")


def get_dataset_info(dataset_file_path):
    """Read *dataset_file_path* and return its parsed JSON content."""
    with open(dataset_file_path, 'r') as handle:
        return json.load(handle)


def get_dataset_count(dataset_info_dict, diagram_id, variable_type):
    """Return how many shapes or arrows are recorded for *diagram_id*.

    ``variable_type == "shape_count"`` counts ``"shapes"``; any other value
    counts ``"arrows"``.
    """
    if variable_type == "shape_count":
        collection_key = "shapes"
    else:
        collection_key = "arrows"
    diagram_record = dataset_info_dict[diagram_id]
    return len(diagram_record[collection_key])


def get_json_from_path_list(data_path_list):
    """Load every JSON file in *data_path_list* and concatenate their lists.

    Files that are missing, are not valid JSON, fail with another ValueError
    while loading, or whose top-level value is not a list are reported on
    stdout and skipped; loading continues with the next path.
    """
    merged = []
    for json_path in data_path_list:
        try:
            with open(json_path, 'r') as fh:
                payload = json.load(fh)
        except FileNotFoundError:
            print(f"Error: The file {json_path} was not found.")
            continue
        except json.JSONDecodeError:
            print(f"Error: The file {json_path} is not a valid JSON file.")
            continue
        except ValueError as err:
            # Other ValueErrors (e.g. UnicodeDecodeError) are reported too.
            print(f"Error: {err}")
            continue

        if isinstance(payload, list):
            merged.extend(payload)
        else:
            print(f"Error: The file {json_path} does not contain a list.")
    return merged


def extract_answer_letter(model_answer):
    """Extract the lowercase answer letter from a free-text model answer.

    Heuristic: locate the last ``")"``, take the last whitespace-separated
    token before it, lowercase it, and return its final character (so both
    ``"(B)"`` and ``"B)"`` give ``"b"``). Returns ``""`` when there is no
    ``")"`` or no token precedes it.
    """
    close_pos = model_answer.rfind(")")
    if close_pos == -1:
        return ""
    tokens = model_answer[:close_pos].split()
    if not tokens:
        return ""
    last_token = tokens[-1].lower().strip()
    return last_token[-1] if last_token else ""


def _detect_dataset_name(data_path_list):
    """Return "icon_dqa" if any input path contains "synthetic", else "real_dqa"."""
    if any("synthetic" in path for path in data_path_list):
        return "icon_dqa"
    return "real_dqa"


def _entry_variable_value(entry, variable_type, variable_value, dataset_info_cache):
    """Return the filter-variable value for one icon_dqa *entry*.

    - ``"count"`` in *variable_type*: number of shapes/arrows recorded for
      ``entry["q_id"]`` in the dataset's ``diagram_info.json`` (each file is
      loaded once and memoized in *dataset_info_cache*).
    - ``"position"`` in *variable_type*: *variable_value* if it occurs in
      ``entry['question']``, otherwise ``""``.
    - otherwise ``0``.
    """
    if "count" in variable_type:
        dataset_path = get_dataset_file_path(entry["image_path"])
        if dataset_path not in dataset_info_cache:
            dataset_info_cache[dataset_path] = get_dataset_info(dataset_path)
        return get_dataset_count(
            dataset_info_cache[dataset_path], entry["q_id"], variable_type)
    if "position" in variable_type:
        return variable_value if variable_value in entry['question'] else ""
    return 0


def calculate_metrics(data_path_list, q_type="", q_component="", variable_type="", variable_value=0, sample_choice="", sample_file=""):
    """Compute accuracy and macro-F1 of model answers against ground truth.

    Result entries are loaded from all JSON files in *data_path_list*. The
    dataset is treated as "icon_dqa" if any path contains "synthetic", and as
    "real_dqa" otherwise; the two use different entry filters.

    Args:
        data_path_list: JSON files, each holding a list of result entries.
        q_type: icon_dqa only. "all" keeps every entry, "count" keeps entries
            with a truthy ``count_question``, "existence" keeps the rest.
            Ignored for real_dqa.
        q_component: real_dqa only. The ``question_component`` to keep;
            "all" keeps every component.
        variable_type: icon_dqa only. Contains "count" (compare the diagram's
            shape/arrow count) or "position" (check whether *variable_value*
            appears in the question text); otherwise the per-entry value is 0.
        variable_value: the per-entry value an icon_dqa entry must equal.
        sample_choice: real_dqa only. When non-empty, keep only entries whose
            image (``q_id`` up to its last "_", plus ".png") is listed under
            this key in *sample_file*.
        sample_file: JSON file of samples; read only if *sample_choice* is set.

    Returns:
        ``(accuracy, f1)`` from sklearn's ``accuracy_score`` and
        ``f1_score(average='macro')`` over the kept entries.
    """
    combined_data = get_json_from_path_list(data_path_list)
    # The dataset flavour depends only on the input paths, so decide it once
    # rather than re-scanning data_path_list for every entry.
    dataset_name = _detect_dataset_name(data_path_list)

    samples = {}
    if sample_choice:
        with open(sample_file, 'r') as file:
            samples = json.load(file)

    y_true = []
    y_pred = []
    dataset_info_cache = {}  # diagram_info.json path -> parsed content
    first_match = True

    for entry in combined_data:
        if dataset_name == "icon_dqa":
            curr_val = _entry_variable_value(
                entry, variable_type, variable_value, dataset_info_cache)
            type_matches = (
                q_type == "all"
                or (q_type == "count" and entry['count_question'])
                or (q_type == "existence" and not entry['count_question'])
            )
            criteria = type_matches and curr_val == variable_value
        else:
            # "all" (the CLI default) is a wildcard, mirroring q_type == "all"
            # above; a literal comparison would match no entries at all.
            criteria = (q_component == "all"
                        or entry["question_component"] == q_component)
            if sample_choice:
                image_name = entry["q_id"].rsplit('_', 1)[0] + ".png"
                criteria = criteria and sample_choice in samples and (
                    image_name in samples[sample_choice])

        if not criteria:
            continue

        if first_match:
            # Debug aid: show the first q_id that passed the filter.
            print(entry["q_id"])
            first_match = False
        # Lowercase so labels compare equal to extract_answer_letter output.
        y_true.append(entry['correct_answer_letter'].lower())
        y_pred.append(extract_answer_letter(entry['model_answer']))

    accuracy = accuracy_score(y_true, y_pred)
    # "" marks answers from which no letter could be extracted.
    empty_string_count = y_pred.count("")
    print(q_type, variable_type, variable_value,
          "All count: ", len(y_true), "Empty count: ", empty_string_count)

    # 'macro' averages the per-letter F1 scores with equal weight.
    f1 = f1_score(y_true, y_pred, average='macro')

    return accuracy, f1


if __name__ == "__main__":
    # Command-line entry point: score one GPT output file and print metrics.
    cli = argparse.ArgumentParser(
        description='Run the evaluation script for a specified GPT output.'
    )
    cli.add_argument(
        'gpt_output_path',
        type=str,
        help='The file for gpt_output.',
    )
    cli.add_argument(
        '--q_type',
        type=str,
        default="all",
        help='The question type from ["all", "count", "existence"]',
    )
    cli.add_argument(
        '--q_component',
        type=str,
        default="all",
        help='The question component.',
    )
    cli_args = cli.parse_args()

    acc, macro_f1 = calculate_metrics(
        [cli_args.gpt_output_path], cli_args.q_type, cli_args.q_component
    )
    print(f"Accuracy: {acc}")
    print(f"F1 Score: {macro_f1}")
