import string
import argparse
import statistics

def load_prompts_from_txt(file_path, *, encoding="utf-8"):
    """Load prompts from a text file, one per line, normalised to end in "?".

    Each line is stripped of surrounding whitespace and then of surrounding
    double quotes. For a non-empty prompt, any run of trailing ASCII
    punctuation is replaced by a single "?"; a prompt with no trailing
    punctuation simply gets "?" appended. Lines that are empty after
    stripping are kept as empty strings.

    Args:
        file_path: Path to the text file.
        encoding: Text encoding of the file. Defaults to UTF-8 so results do
            not depend on the platform's locale encoding.

    Returns:
        A list with one string per line of the file, in file order.
    """
    prompts = []
    with open(file_path, "r", encoding=encoding) as f:
        for line in f:
            # Strip the surrounding double quotes and any leading/trailing whitespace
            prompt = line.strip().strip('"')
            # rstrip is a no-op when the prompt does not end in punctuation,
            # so this one expression covers both "replace" and "append".
            if prompt:
                prompt = prompt.rstrip(string.punctuation) + "?"
            prompts.append(prompt)
    return prompts


def load_predicts_from_txt(file_path, *, encoding="utf-8"):
    """Load predicted labels from a text file, one per line.

    Each line is stripped of surrounding whitespace and then of surrounding
    double quotes; no other normalisation is applied (unlike
    load_prompts_from_txt, no "?" is added). Lines that are empty after
    stripping are kept as empty strings.

    Args:
        file_path: Path to the text file.
        encoding: Text encoding of the file. Defaults to UTF-8 so results do
            not depend on the platform's locale encoding.

    Returns:
        A list with one string per line of the file, in file order.
    """
    with open(file_path, "r", encoding=encoding) as f:
        return [line.strip().strip('"') for line in f]


def sanity_file(
    src_path="data/original.txt",
    dst_path="labels/predicted_original.txt",
    repeats=5,
    *,
    encoding="utf-8",
):
    """Write a predictions file in which every prediction is the expected label.

    Loads prompts from *src_path* via load_prompts_from_txt and writes each
    one *repeats* times, one per line, to *dst_path*, overwriting it.

    Args:
        src_path: File of expected prompts.
        dst_path: Predictions file to create or overwrite.
        repeats: Number of times each label is written.
        encoding: Text encoding used for the output file.
    """
    expected = load_prompts_from_txt(src_path)
    with open(dst_path, "w", encoding=encoding) as out:
        out.writelines(f"{label}\n" for label in expected for _ in range(repeats))


def _validate_predictions(expected, real, repeats):
    """Raise ValueError unless *real* can be scored against *expected*.

    Scoring compares prediction ``i`` with ``expected[i // repeats]``, so
    *real* must be non-empty and no longer than ``len(expected) * repeats``.
    """
    if repeats < 1:
        raise ValueError(f"repeats must be a positive integer, got {repeats}")
    if not real:
        raise ValueError("no predictions to score")
    capacity = len(expected) * repeats
    if len(real) > capacity:
        raise ValueError(
            f"{len(real)} predictions but only {len(expected)} expected labels "
            f"({repeats} predictions each, {capacity} max)"
        )


def compute_accuracy(expected, real, repeats=5):
    """Score predictions against expected labels.

    *real* holds ``repeats`` consecutive predictions per expected label, so
    prediction ``i`` is compared with ``expected[i // repeats]``. The
    comparison ignores leading/trailing whitespace.

    Args:
        expected: Expected labels, one per prompt.
        real: Predicted labels, ``repeats`` per prompt, in prompt order.
        repeats: Number of predictions recorded per expected label.

    Returns:
        A tuple ``(accuracy, correct)``: the percentage (0-100) of
        predictions that match, and the indices into *real* that matched.

    Raises:
        ValueError: If *real* is empty, *repeats* is not positive, or *real*
            has more entries than ``len(expected) * repeats``.
    """
    _validate_predictions(expected, real, repeats)
    correct = [
        i
        for i, real_out in enumerate(real)
        if expected[i // repeats].strip() == real_out.strip()
    ]
    return len(correct) / len(real) * 100, correct


def chunk_accuracies(expected, real, repeats=5):
    """Return the accuracy (%) of each group of *repeats* predictions.

    Group ``k`` is ``real[k * repeats:(k + 1) * repeats]`` and is scored
    against ``expected[k]``, ignoring leading/trailing whitespace. A shorter
    final group is scored over its actual length.

    Raises:
        ValueError: Under the same conditions as compute_accuracy.
    """
    _validate_predictions(expected, real, repeats)
    accuracies = []
    for start in range(0, len(real), repeats):
        chunk = real[start:start + repeats]
        expecting = expected[start // repeats].strip()
        hits = sum(expecting == out.strip() for out in chunk)
        accuracies.append(hits / len(chunk) * 100)
    return accuracies


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Predictions are read from labels/predicted_<kind>.txt; kinds used so
    # far include "original", "semantics_sim" and "structural_sim".
    parser.add_argument(
        "--kind",
        type=str,
        required=True,
        help="The trace kinds",
    )
    # Expected prompts are read from input_data/<Inkind>.txt.
    parser.add_argument(
        "--Inkind",
        type=str,
        required=True,
        help="The trace Input kinds",
    )
    parser.add_argument(
        "--repeats",
        type=int,
        default=5,
        help="Number of predictions recorded per expected prompt (default: 5)",
    )
    parser.add_argument(
        "--chunk-stats",
        action="store_true",
        help="Also print max/min/mean/stdev of per-prompt accuracy",
    )
    parser.add_argument(
        "--show-correct",
        action="store_true",
        help="Also print the indices of predictions that matched",
    )
    args = parser.parse_args()

    # Self-check (call manually, not wired to the CLI): sanity_file() writes
    # labels/predicted_original.txt with every label from data/original.txt
    # repeated 5 times.
    expected = load_prompts_from_txt(f"input_data/{args.Inkind}.txt")
    real = load_predicts_from_txt(f"labels/predicted_{args.kind}.txt")

    try:
        accuracy, correct = compute_accuracy(expected, real, args.repeats)
    except ValueError as err:
        raise SystemExit(f"error: {err}") from err

    print(accuracy)
    if args.show_correct:
        print(correct)

    if args.chunk_stats:
        accuracies = chunk_accuracies(expected, real, args.repeats)
        print(f"Max Accuracy: {max(accuracies)}%")
        print(f"Min Accuracy: {min(accuracies)}%")
        print(f"Mean Accuracy: {statistics.mean(accuracies)}%")
        # statistics.stdev needs at least two data points.
        if len(accuracies) > 1:
            print(f"Standard Deviation: {statistics.stdev(accuracies)}%")

    # NOTE(review): the list below looks like a saved `correct` output from an
    # earlier run; confirm whether it is still needed.
    # [16, 19, 21, 22, 24, 31, 32, 35, 36, 37, 38, 39, 45, 48, 78, 100, 102, 105, 110, 118, 119, 132, 138, 148, 150, 176, 184, 204, 205, 209, 211, 212, 213, 220, 224, 231, 232, 233, 234, 241, 242, 244]