import os
import csv
import json
import argparse
from tqdm import tqdm


def get_data_idx(file_name):
    """Load a JSON list of records and index it by each record's ``'idx'``.

    Args:
        file_name: Path to a JSON file whose top level is an iterable of
            objects, each carrying an ``'idx'`` key.

    Returns:
        A ``(data, idx_dict)`` tuple. ``data`` is the parsed JSON content and
        ``idx_dict`` maps every ``item['idx']`` to that item's position in
        ``data``. If an ``'idx'`` value occurs more than once, the position
        of the last occurrence is kept.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If a record has no ``'idx'`` key.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's
    # locale-dependent default, which can fail on non-ASCII content.
    with open(file_name, "r", encoding="utf-8") as f:
        data = json.load(f)
    idx_dict = {item['idx']: position for position, item in enumerate(data)}
    return data, idx_dict


def main(args):
    """Score a solution file against the annotations and write a CSV report.

    For every index listed in ``intersect/intersect_idx.json`` that is present
    in both the annotation file and the solution file, each of the four
    question types gets a score of 1 when the annotated answer matches the
    ``fold_1`` response and also matches the ``fold_2`` response after it is
    mapped through A<->D / B<->C; otherwise the score is 0.

    NOTE(review): the A<->D / B<->C mapping presumably undoes a reversed
    option order used for ``fold_2`` -- confirm against the code that
    generates the solutions.

    Args:
        args: Parsed command-line namespace. Only ``args.sol_name`` is read;
            it selects ``evaluate/raw/solution_<sol_name>.json`` as input and
            ``evaluate/comparison/cmp_<sol_name>.csv`` as output.

    Returns:
        None. The result is written to the CSV file (header row ``idx`` plus
        the four type names, then one row per scored index).

    Raises:
        OSError: If an input file cannot be read or the output cannot be
            written.
        KeyError: If a record lacks one of the expected keys
            (``'annotation'``, ``'solution'``, ``'fold_1'``, ``'fold_2'``,
            a question type, or ``'Answer'``).
    """
    ### Loading datasets
    intersect_idx_path = os.path.join("intersect", "intersect_idx.json")
    with open(intersect_idx_path, 'r', encoding='utf-8') as file:
        intersect_idx_dict = json.load(file)
    # JSON object keys are always strings; the indices are compared as ints.
    idx_list = [int(key) for key in intersect_idx_dict]

    ### Loading annotations and solutions
    annotation_file = os.path.join("evaluate", "raw", "annotation.json")
    annotations, annotation_idx_dict = get_data_idx(annotation_file)
    solution_file = os.path.join("evaluate", "raw", f"solution_{args.sol_name}.json")
    solutions, solution_idx_dict = get_data_idx(solution_file)

    ### Creating output directories
    comparison_dir = os.path.join(os.getcwd(), "evaluate", "comparison")
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(comparison_dir, exist_ok=True)
    output_file = os.path.join(comparison_dir, f"cmp_{args.sol_name}.csv")

    ### Verifying
    verifications = []
    types = ['Recognition', 'Understanding', 'Grounding', 'Reasoning']
    rev_map = {'A': 'D', 'B': 'C', 'C': 'B', 'D': 'A'}
    for idx in tqdm(idx_list, total=len(idx_list), desc="Processing diagrams"):
        # Indices missing from either file are skipped, not scored as wrong.
        if idx not in annotation_idx_dict or idx not in solution_idx_dict:
            continue
        anno = annotations[annotation_idx_dict[idx]]['annotation']
        solution = solutions[solution_idx_dict[idx]]['solution']
        sol_1 = solution['fold_1']
        sol_2 = solution['fold_2']

        # Correct in both folds --> 1, else --> 0
        scores = []
        for t in types:
            answer = anno[t]['Answer']
            fold_1_ok = answer == sol_1[t]
            # A fold_2 response outside A-D has no mapped counterpart; count
            # it as incorrect instead of raising KeyError and losing the
            # whole run.
            fold_2_choice = sol_2[t]
            fold_2_ok = fold_2_choice in rev_map and answer == rev_map[fold_2_choice]
            scores.append(int(fold_1_ok and fold_2_ok))
        verifications.append((idx, scores))

    ### Saving results
    with open(output_file, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(['idx'] + types)
        writer.writerows([idx] + scores for idx, scores in verifications)


if __name__ == '__main__':
    # Script entry point: read the solution name from the command line
    # (``--sol_name``, default ``empty_v2``) and run the comparison.
    cli = argparse.ArgumentParser(description='')
    cli.add_argument('--sol_name', default='empty_v2', type=str)
    main(cli.parse_args())
