import os
import json
import glob
import argparse
from collections import defaultdict


def load_json_files(folder_path):
    """Return the sorted paths of all ``*.json`` files directly inside a folder.

    Args:
        folder_path: Directory to search (not recursive).

    Returns:
        A list of matching paths in ascending lexicographic order; empty when
        the folder has no ``.json`` files or does not exist.
    """
    # Escape the directory part so characters such as "[", "*" or "?" in the
    # folder name are matched literally instead of being read as wildcards.
    pattern = os.path.join(glob.escape(folder_path), "*.json")
    return sorted(glob.glob(pattern))


def process_files(args):
    """Score every subject file in a folder and write a response matrix.

    Each ``*.json`` file in ``args.folder_path`` is treated as one subject
    (the subject id is the file name without its extension). The file is
    iterated as a sequence of dict-like entries from which the keys ``"id"``,
    ``"check"`` and, optionally, ``"difficulty"`` are read. An entry counts as
    correct only when its ``"check"`` value equals ``"right"``.

    For every subject this prints the accuracy over all entries in the file
    and builds a 0/1 response vector keyed ``item_1`` .. ``item_N``. Entries
    whose id does not map to one of those keys still count towards accuracy
    but are not stored in the vector. Correct/total counts per difficulty are
    accumulated across all files and printed at the end. Finally one JSON
    object per subject is written to ``args.output_jsonl_path`` (JSON Lines).

    Args:
        args: Namespace-like object providing ``folder_path`` and
            ``output_jsonl_path``. An optional integer ``num_items``
            attribute sets N, the number of item slots per subject
            (defaults to 2000 when absent or ``None``).
    """
    num_items = getattr(args, "num_items", None)
    if num_items is None:
        num_items = 2000
    # Build the key list once; every subject starts from the same all-zero row.
    item_keys = [f"item_{number}" for number in range(1, num_items + 1)]

    response_matrix = []
    difficulty_stats = defaultdict(lambda: {"correct": 0, "total": 0})

    for file_path in load_json_files(args.folder_path):
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        subject_id = os.path.splitext(os.path.basename(file_path))[0]
        responses = dict.fromkeys(item_keys, 0)
        correct_count = 0

        for entry in data:
            # A missing "check" is scored as incorrect rather than aborting
            # the whole run with a KeyError.
            is_correct = entry.get("check") == "right"

            difficulty = entry.get("difficulty")
            if difficulty is not None:
                stats = difficulty_stats[difficulty]
                stats["total"] += 1
                if is_correct:
                    stats["correct"] += 1

            # A missing or out-of-range id yields a key that is not in
            # ``responses``, so the entry is simply left out of the matrix.
            item_key = f"item_{entry.get('id')}"
            if item_key in responses:
                responses[item_key] = int(is_correct)
            if is_correct:
                correct_count += 1

        total_items_in_file = len(data)
        accuracy = correct_count / total_items_in_file if total_items_in_file > 0 else 0
        print(f"{subject_id} accuracy: {accuracy:.2%}")

        response_matrix.append(
            {
                "subject_id": subject_id,
                "responses": responses,
            }
        )

    print("\n" + "=" * 50)
    print("Accuracy by Difficulty Level")
    print("=" * 50)

    try:
        sorted_difficulties = sorted(difficulty_stats)
    except TypeError:
        # Mixed label types (e.g. 1 and "hard") cannot be compared directly;
        # fall back to ordering by their string form.
        sorted_difficulties = sorted(difficulty_stats, key=str)

    for difficulty in sorted_difficulties:
        stats = difficulty_stats[difficulty]
        total = stats["total"]
        correct = stats["correct"]
        accuracy = correct / total if total > 0 else 0
        print(f"Difficulty {difficulty}: {correct}/{total} correct, Accuracy: {accuracy:.2%}")

    with open(args.output_jsonl_path, "w", encoding="utf-8") as out_file:
        out_file.writelines(json.dumps(record) + "\n" for record in response_matrix)


if __name__ == "__main__":
    # Command-line entry point. Both options fall back to environment
    # variables (DATA_FOLDER / OUTPUT_JSONL) when not given explicitly.
    parser = argparse.ArgumentParser(
        description="Build a per-subject response matrix from JSON result files."
    )
    parser.add_argument(
        "--folder-path",
        default=os.environ.get("DATA_FOLDER", ""),
        help="folder containing one JSON file per subject (default: $DATA_FOLDER)",
    )
    parser.add_argument(
        "--output-jsonl-path",
        default=os.environ.get("OUTPUT_JSONL", "response_matrix.jsonl"),
        help="where to write the JSON Lines output (default: $OUTPUT_JSONL or response_matrix.jsonl)",
    )
    args = parser.parse_args()

    if not args.folder_path or not os.path.isdir(args.folder_path):
        # Raising SystemExit with a string prints it to stderr and exits with
        # status 1, so shells and pipelines can detect the failure.
        raise SystemExit(f"Error: folder path does not exist -> {args.folder_path}")

    process_files(args)