import os
import json
import argparse
from tqdm import tqdm
from pathlib import Path
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from evaluation.evaluate import sample_match
from inference.synthesis_procedure_extraction import normalize_jsonld


def extract_doi_from_filename(filename):
    """Recover a DOI from a ``<doi>.json`` file name.

    The trailing ``.json`` extension (if present) is removed and every ``_``
    is turned into ``/``, e.g.
    ``"10.1021_jacs.5b00001.json"`` -> ``"10.1021/jacs.5b00001"``.

    Args:
        filename: Base name of a prediction / ground-truth JSON file.

    Returns:
        The reconstructed DOI string.

    NOTE(review): file names were presumably produced by replacing ``/`` with
    ``_``; if so, a DOI that itself contains ``_`` is not recovered exactly,
    because every ``_`` is mapped back to ``/`` here.
    """
    name = filename
    # Strip only the trailing extension; ``str.replace`` would also delete a
    # ".json" substring occurring inside the DOI itself.
    if name.endswith(".json"):
        name = name[: -len(".json")]
    return name.replace("_", "/")


def load_jsons_from_dir(json_dir):
    """Parse every ``*.json`` file located directly inside *json_dir*.

    Files are visited in sorted path order, so the returned dict's insertion
    order follows that ordering. Subdirectories are not searched.

    Args:
        json_dir: Directory path (str or path-like) to scan.

    Returns:
        A dict mapping each file's base name (e.g. ``"x.json"``) to its
        decoded JSON content.
    """
    directory = Path(json_dir)
    loaded = {}
    for path in sorted(directory.glob("*.json")):
        with path.open("r", encoding="utf-8") as handle:
            loaded[path.name] = json.load(handle)
    return loaded


def _write_prediction_entries(fout, fname, pred_entries):
    """Write one JSONL record per prediction entry taken from file *fname*.

    Each record holds the DOI recovered from *fname*, the entry's ``"label"``
    (``""`` when absent), and the entry's remaining keys passed through
    ``normalize_jsonld`` under ``"prov_jsonld"``.
    """
    # The DOI depends only on the file name, so compute it once per file.
    doi = extract_doi_from_filename(fname)
    for pred_entry in pred_entries:
        label = pred_entry.get("label", "")
        prov_jsonld = {k: v for k, v in pred_entry.items() if k != "label"}
        # normalize_jsonld is called with a one-element list and the first
        # result is taken.
        normalized_prov_jsonld = normalize_jsonld([prov_jsonld])[0]
        record = {
            "doi": doi,
            "label": label,
            "prov_jsonld": normalized_prov_jsonld,
        }
        fout.write(json.dumps(record, ensure_ascii=False) + "\n")


def main(input_dir, true_path, output_path):
    """Convert per-paper prediction JSON files into a single JSONL file.

    Files present in both directories are filtered through ``sample_match``.
    Only the first value it returns is written; presumably these are the
    predictions matched to a ground-truth sample. Prediction files with no
    ground-truth counterpart have every entry written. Ground-truth files
    with no prediction counterpart produce no output.

    Args:
        input_dir: Directory containing generated (prediction) JSON files.
        true_path: Directory containing ground-truth JSON files.
        output_path: Destination JSONL file; it is overwritten.

    Raises:
        NotADirectoryError: If *input_dir* or *true_path* is not an existing
            directory.
    """
    # A mistyped directory would otherwise glob to nothing and silently yield
    # empty output, or output that skips ground-truth filtering entirely.
    for directory in (input_dir, true_path):
        if not Path(directory).is_dir():
            raise NotADirectoryError(f"Not a directory: {directory}")

    pred_data_dict = load_jsons_from_dir(input_dir)
    true_data_dict = load_jsons_from_dir(true_path)

    pred_files = set(pred_data_dict.keys())
    true_files = set(true_data_dict.keys())
    common_files = sorted(pred_files & true_files)
    pred_only_files = sorted(pred_files - true_files)

    print(
        f"Found {len(common_files)} matching files in prediction and ground truth directories."
    )
    print(f"Found {len(pred_only_files)} prediction files without ground truth.")

    with open(output_path, "w", encoding="utf-8") as fout:
        for fname in tqdm(common_files):
            matched_pred_samples, _ = sample_match(
                pred_data_dict[fname], true_data_dict[fname]
            )
            _write_prediction_entries(fout, fname, matched_pred_samples)
        for fname in tqdm(pred_only_files):
            _write_prediction_entries(fout, fname, pred_data_dict[fname])


if __name__ == "__main__":
    # Command-line entry point: every option is required, and the parsed
    # values are forwarded to main() by keyword.
    cli = argparse.ArgumentParser(
        description="Convert generated JSON files to a single JSONL file."
    )
    cli.add_argument(
        "--input-dir",
        required=True,
        type=str,
        help="Directory containing generated JSON files",
    )
    cli.add_argument(
        "--true-path",
        required=True,
        type=str,
        help="Directory containing ground truth JSON files",
    )
    cli.add_argument(
        "--output",
        required=True,
        type=str,
        help="Output JSONL file path",
    )
    parsed = cli.parse_args()
    main(
        input_dir=parsed.input_dir,
        true_path=parsed.true_path,
        output_path=parsed.output,
    )
