import string
import argparse
import statistics

def load_prompts_from_txt(file_path):
    """Read prompts from a text file, one per line, normalised to end in "?".

    Each line is stripped of surrounding whitespace and then of surrounding
    double quotes. For a non-empty result, every trailing punctuation
    character is removed and a single question mark is appended. Blank lines
    are kept as empty strings, so the returned list has one entry per line
    of the file.

    Args:
        file_path: Path of the text file to read.

    Returns:
        A list of normalised prompt strings, in file order.
    """
    trailing_marks = string.punctuation
    cleaned = []
    with open(file_path, "r") as handle:
        for raw in handle:
            text = raw.strip().strip('"')
            if not text:
                # Nothing to normalise; keep the empty entry as a placeholder.
                cleaned.append(text)
                continue
            # rstrip leaves the text untouched when it does not end in
            # punctuation, so one expression covers both cases.
            cleaned.append(text.rstrip(trailing_marks) + "?")
    return cleaned


def load_predicts_from_txt(file_path):
    """Read predicted outputs from a text file, one per line.

    Each line is stripped of surrounding whitespace and then of surrounding
    double quotes; no other normalisation is applied. Blank lines are kept
    as empty strings.

    Args:
        file_path: Path of the text file to read.

    Returns:
        A list with one cleaned string per line, in file order.
    """
    with open(file_path, "r") as handle:
        return [raw.strip().strip('"') for raw in handle]


def sanity_file():
    """Write a predictions file in which every prediction is the expected label.

    Loads the normalised prompts from "data/original.txt" and writes each one
    five times, one per line, to "labels/predicted_original.txt".
    """
    labels = load_prompts_from_txt("data/original.txt")
    with open("labels/predicted_original.txt", "w") as out:
        for label in labels:
            # Five identical copies of each label.
            for _ in range(5):
                out.write(f"{label}\n")


def compute_accuracy(expected, real, group_size=5):
    """Return the percentage of predictions that match their expected prompt.

    Prediction ``i`` is compared against ``expected[i // group_size]``, i.e.
    each expected prompt covers ``group_size`` consecutive predictions. Both
    sides are stripped of surrounding whitespace before comparison.

    Args:
        expected: Sequence of expected prompt strings.
        real: Sequence of predicted strings.
        group_size: Number of consecutive predictions per expected prompt.

    Returns:
        The accuracy as a float percentage in the range 0-100.

    Raises:
        ValueError: If ``group_size`` is smaller than 1, if ``real`` is
            empty (the accuracy would be a division by zero), or if ``real``
            holds more predictions than ``expected`` can cover.
    """
    if group_size < 1:
        raise ValueError(f"group_size must be at least 1, got {group_size}")
    if not real:
        raise ValueError("no predictions to score")

    capacity = len(expected) * group_size
    if len(real) > capacity:
        raise ValueError(
            f"{len(real)} predictions but only {len(expected)} expected "
            f"prompts ({group_size} predictions each, at most {capacity})"
        )

    matching = sum(
        1
        for index, predicted in enumerate(real)
        if expected[index // group_size].strip() == predicted.strip()
    )
    return (matching / len(real)) * 100


if __name__ == "__main__":
    # Command line: --in_data names the expected-prompt file (without the
    # ".txt" suffix); --out_data selects "labels/predicted_<out_data>.txt".
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--in_data",
        type=str,
        default='',
        help="The input dataset.",
    )
    parser.add_argument(
        "--out_data",
        type=str,
        default='',
        help="The output dataset.",
    )
    args = parser.parse_args()

    expected = load_prompts_from_txt(f"{args.in_data}.txt")
    real = load_predicts_from_txt(f"labels/predicted_{args.out_data}.txt")

    try:
        accuracy = compute_accuracy(expected, real)
    except ValueError as err:
        # Empty or misaligned input files: exit with a readable message
        # instead of a ZeroDivisionError / IndexError traceback.
        raise SystemExit(f"error: {err}") from err

    print(accuracy)