import csv
import json
import os

import nltk

# results_file = "checkpoints/VIST-AMT/vist_amt_sample_results_easy_distractors.csv"
results_file = "checkpoints/VIST-AMT/vist_amt_sample_results_hard_seed_123.csv"
num_total = 0
num_correct = 0
if __name__ == "__main__":
    data = {}
    with open(results_file) as f:
        input_file = csv.DictReader(f)
        for row in input_file:
            gt_index = int(row["Input.gt_index"])
            num_total += 1
            num_correct += int(
                row["Answer.img-%d.img-%d" % (gt_index + 1, gt_index + 1)] == "true"
            )
            key = "-".join(row["Input.context_url_%d" % j] for j in range(1, 5))
            if key not in data:
                data[key] = []
            if "true" not in [
                row["Answer.img-%d.img-%d" % (j, j)] for j in range(1, 6)
            ]:
                continue
            data[key].append(
                [row["Answer.img-%d.img-%d" % (j, j)] for j in range(1, 6)].index(
                    "true"
                )
            )
    num_agreement_atleast_one = 0
    for sample in data:
        num_agreement_atleast_one += int(len(data[sample]) > len(set(data[sample])))

    print(
        "percentage of samples with overlap among annotators",
        num_agreement_atleast_one / len(data),
    )
    print("num total", num_total)
    print("num correct", num_correct)
    print("accuracy", num_correct / num_total)
    # agreement among annotations
