import json
import os
import argparse

# Build per-example feature vectors from a JSON results file.
#
# The input is a single JSON object (read with json.load) mapping an example
# key (a path containing "/train/" or "/test/") to a dict of reference ids.
# Each reference id maps to a dict of threshold -> score.
#
# For every example, the scores at --threshold are grouped by the model name
# encoded in the reference id. They are pooled per model with max or avg, and
# the result is a list with one value per model, in sorted model-name order.
# One JSON file per split is written under
# <output_dir>/threshold=<t>_aggregation=<mode>/.


def parse_args(argv=None):
	"""Parse the command-line options.

	Args:
		argv: Argument list to parse. ``None`` means argparse's default,
			which is ``sys.argv[1:]``.

	Returns:
		argparse.Namespace with ``results_file``, ``threshold`` (float),
		``score_aggregation_mode`` ("max" or "avg") and ``output_dir``.
	"""
	parser = argparse.ArgumentParser(description="Extract features from a JSON results file.")
	parser.add_argument("--results_file", type=str, required=True, help="Path to the input JSON file (a single JSON object, not JSON Lines).")
	parser.add_argument("--threshold", type=float, required=True, help="Similarity threshold for exact match.")
	parser.add_argument("--score_aggregation_mode", type=str, choices=["max", "avg"], default="max", help="Pooling mode for scores across multiple reference generations from same model.")
	parser.add_argument("--output_dir", type=str, required=True, help="Directory to save the output features.")
	return parser.parse_args(argv)


def _split_of(key):
	"""Return "train" or "test" based on a path segment in *key*, else None.

	"/train/" is checked first, so a key containing both segments is train.
	"""
	if "/train/" in key:
		return "train"
	if "/test/" in key:
		return "test"
	# Other keys (e.g. dev examples) are not expected here. Callers skip them.
	return None


def _ref_gen_model(ref_id):
	"""Return the model name encoded in *ref_id*.

	The model name is the second "@"-separated field, i.e. ``ref_id.split("@")[1]``.

	Raises:
		ValueError: if *ref_id* contains no "@".
	"""
	parts = ref_id.split("@")
	if len(parts) < 2:
		raise ValueError(f"Reference id {ref_id!r} has no '@'-separated model name.")
	return parts[1]


def _score_at_threshold(scores_by_threshold, threshold):
	"""Return the score stored under *threshold* in *scores_by_threshold*.

	The exact key ``str(threshold)`` (e.g. "0.5") is tried first. If it is
	absent, any key whose numeric value equals *threshold* is accepted.
	Without this fallback, ``--threshold 1`` (parsed to 1.0, so "1.0") would
	miss a key written as "1", and 0.5 would miss "0.50".

	Raises:
		KeyError: if no key matches. The message lists the available keys.
	"""
	exact_key = str(threshold)
	if exact_key in scores_by_threshold:
		return scores_by_threshold[exact_key]
	for candidate in scores_by_threshold:
		try:
			if float(candidate) == threshold:
				return scores_by_threshold[candidate]
		except (TypeError, ValueError):
			# Non-numeric key: it cannot be the requested threshold.
			continue
	raise KeyError(f"No score for threshold {exact_key!r}; available thresholds: {list(scores_by_threshold)}")


def _aggregate(scores, mode):
	"""Pool a non-empty list of scores: max for "max", otherwise the mean."""
	if mode == "max":
		return max(scores)
	return sum(scores) / len(scores)  # "avg"


def build_features(results_data, threshold, score_aggregation_mode="max"):
	"""Compute per-example feature vectors, grouped by split.

	Args:
		results_data: Mapping of example key -> {ref_id: {threshold_key: score}}.
		threshold: Threshold whose score is read from each reference.
		score_aggregation_mode: "max" or "avg"; pools the scores of all
			reference generations from the same model.

	Returns:
		``{"train": {key: [score, ...]}, "test": {key: [score, ...]}}``.
		Each list has one pooled score per reference-generation model, in
		sorted model-name order. Keys matching neither split are skipped, and
		examples with no reference ids get no entry.

		NOTE(review): position i means "the i-th model name in sorted order"
		for that example only. This assumes every example has scores from the
		same set of models; that is not enforced here.

	Raises:
		ValueError: a reference id has no "@"-separated model name.
		KeyError: a reference has no score for *threshold*.
	"""
	review2features = {"train": {}, "test": {}}
	for key, val in results_data.items():
		split = _split_of(key)
		if split is None:
			continue

		ref_gen_model_wise_scores = {}
		# Sorted so that "avg" sums scores in a deterministic order.
		for ref_id in sorted(val):
			ref_gen_model_wise_scores.setdefault(_ref_gen_model(ref_id), []).append(
				_score_at_threshold(val[ref_id], threshold)
			)

		if not ref_gen_model_wise_scores:
			# Matches the historical behavior: no references, no feature row.
			continue

		review2features[split][key] = [
			_aggregate(ref_gen_model_wise_scores[model], score_aggregation_mode)
			for model in sorted(ref_gen_model_wise_scores)
		]
	return review2features


def write_features(review2features, output_dir, threshold, score_aggregation_mode):
	"""Write one JSON file per split and return the run directory used.

	Files go to ``<output_dir>/threshold=<t>_aggregation=<mode>/`` and are
	named ``X_<split>_threshold=<t>_aggregation=<mode>.json``. The directory
	is created if missing.
	"""
	run_dir = os.path.join(output_dir, f"threshold={threshold}_aggregation={score_aggregation_mode}/")
	os.makedirs(run_dir, exist_ok=True)
	for split in ["train", "test"]:
		output_file = os.path.join(run_dir, f"X_{split}_threshold={threshold}_aggregation={score_aggregation_mode}.json")
		with open(output_file, "w", encoding="utf-8") as fout:
			json.dump(review2features[split], fout, indent=4)
	return run_dir


def main(argv=None):
	"""CLI entry point: load results, build features, write them to disk."""
	args = parse_args(argv)
	with open(args.results_file, "r", encoding="utf-8") as fin:
		results_data = json.load(fin)
	review2features = build_features(results_data, args.threshold, args.score_aggregation_mode)
	write_features(review2features, args.output_dir, args.threshold, args.score_aggregation_mode)


if __name__ == "__main__":
	main()

