import json
import os
import math
from collections import defaultdict
# Ids accepted by the evaluation: the entry names found directly under the two
# local repository directories. os.listdir raises if either directory is
# missing, so both must exist when this module is imported.
val_ids = os.listdir("./repos_verified") + os.listdir("./repos_pro")
def _edge_case_score(gt, pred):
    """Return the score fixed by emptiness alone, or None if both sets have items.

    Two empty sets count as a perfect match (1.0); exactly one empty set is a
    complete miss (0.0).
    """
    if not gt and not pred:
        return 1.0
    if not gt or not pred:
        return 0.0
    return None


def _dice(gt, pred):
    """Return the Dice coefficient 2*|gt & pred| / (|gt| + |pred|)."""
    score = _edge_case_score(gt, pred)
    if score is not None:
        return score
    return 2 * len(gt & pred) / (len(gt) + len(pred))


def _iou(gt, pred):
    """Return intersection-over-union |gt & pred| / |gt | pred|."""
    score = _edge_case_score(gt, pred)
    if score is not None:
        return score
    return len(gt & pred) / len(gt | pred)


def _precision(gt, pred):
    """Return the fraction of predicted items that are in the ground truth."""
    score = _edge_case_score(gt, pred)
    if score is not None:
        return score
    return len(gt & pred) / len(pred)


def _recall(gt, pred):
    """Return the fraction of ground-truth items that were predicted."""
    score = _edge_case_score(gt, pred)
    if score is not None:
        return score
    return len(gt & pred) / len(gt)


def _hit_at_1(gt, ranked):
    """Return 1.0 if the top-ranked prediction is in the ground truth, else 0.0."""
    if not gt or not ranked:
        return 0.0
    return 1.0 if ranked[0] in gt else 0.0


def _recall_at_k(gt, ranked, k):
    """Return the fraction of ground-truth items found in the first k predictions."""
    if not gt:
        return 0.0
    return len(set(gt) & set(ranked[:k])) / len(gt)


def _ndcg_at_k(gt, ranked, k):
    """Return binary-relevance NDCG over the first k predictions.

    `ranked` must not contain duplicates: a repeated relevant item would be
    credited at every rank it occupies and could push the score above 1.
    """
    if not gt or not ranked:
        return 0.0
    dcg = sum(
        1.0 / math.log2(rank + 2)
        for rank, item in enumerate(ranked[:k])
        if item in gt
    )
    # The ideal ranking puts as many relevant items as fit at the top.
    ideal_hits = min(len(gt), k)
    idcg = sum(1.0 / math.log2(rank + 2) for rank in range(ideal_hits))
    return dcg / idcg if idcg > 0 else 0.0


def _split_entries(raw):
    """Split a comma-separated string into stripped, non-empty entries."""
    return [part.strip() for part in (raw or "").split(",") if part.strip()]


def _unique_in_order(items):
    """Drop repeated items while keeping the first occurrence of each in order."""
    return list(dict.fromkeys(items))


def _score_level(gt, ranked):
    """Return every metric for one record at one granularity as name -> score."""
    pred = set(ranked)
    return {
        "recall": _recall(gt, pred),
        "precision": _precision(gt, pred),
        "dice": _dice(gt, pred),
        "iou": _iou(gt, pred),
        "hit@1": _hit_at_1(gt, ranked),
        "recall@3": _recall_at_k(gt, ranked, 3),
        "recall@5": _recall_at_k(gt, ranked, 5),
        "ndcg@5": _ndcg_at_k(gt, ranked, 5),
    }


def _mean(values):
    """Return the arithmetic mean of `values`, or 0.0 for an empty sequence."""
    return sum(values) / len(values) if values else 0.0


def compute_metrics(jsonl_path):
    """Compute per-round localisation metrics from a JSONL result log.

    Each non-blank line is parsed as a JSON object. Records whose "repo_id" is
    not in the module-level `val_ids` are ignored. The n-th kept record of a
    given repo_id (in file order) is treated as that repo's round n. The "gt"
    and "pred" fields are read as comma-separated strings; the text before
    "::" in an entry is used as its file-level identity, the whole entry as
    its function-level identity. "pred" order is taken as the ranking.

    Args:
        jsonl_path: Path to the UTF-8 encoded JSONL log.

    Returns:
        A dict keyed "round_1", "round_2", ... Each value maps
        "function_level" and "file_level" to a dict of metric name -> mean
        score across repos, scaled to 0-100. Metrics are "recall",
        "precision", "dice", "iou" (set-based) followed by "hit@1",
        "recall@3", "recall@5" and "ndcg@5" (rank-based).

    Raises:
        OSError: If the log cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.

    Side effects:
        Prints the number of records per round and the number of rounds.
    """
    # Build the lookup once; membership is tested for every log line.
    allowed_ids = set(val_ids)

    repo_rounds = defaultdict(list)
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            data = json.loads(line)
            instance_id = data.get("repo_id")
            if instance_id in allowed_ids:
                repo_rounds[instance_id].append(data)

    # round index -> level -> metric name -> per-record scores
    round_scores = defaultdict(lambda: {
        "function_level": defaultdict(list),
        "file_level": defaultdict(list),
    })
    for rounds in repo_rounds.values():
        for round_idx, data in enumerate(rounds):
            gt_funcs = _split_entries(data.get("gt"))
            pred_funcs = _unique_in_order(_split_entries(data.get("pred")))
            # Several predicted functions can share a file; collapse them so
            # each file holds exactly one position in the file-level ranking.
            pred_files = _unique_in_order(x.split("::")[0] for x in pred_funcs)
            inputs = {
                "function_level": (set(gt_funcs), pred_funcs),
                "file_level": ({x.split("::")[0] for x in gt_funcs}, pred_files),
            }
            for level, (gt_set, ranked) in inputs.items():
                for metric, score in _score_level(gt_set, ranked).items():
                    round_scores[round_idx][level][metric].append(score)

    summary = {}
    for round_idx, levels in sorted(round_scores.items()):
        # Number of records that contributed to this round.
        print(len(levels["file_level"]["dice"]))
        summary[f"round_{round_idx+1}"] = {
            level: {
                metric: _mean(scores) * 100
                for metric, scores in metrics.items()
            }
            for level, metrics in levels.items()
        }
    # Number of rounds reported.
    print(len(summary))
    return summary
if __name__ == '__main__':
    # Placeholder value: replace with the path of the JSONL log to evaluate.
    result_path = 'input-your-log'
    result_dict = compute_metrics(result_path)

    # Report layout: a ruler, the round name in capitals, then every metric
    # grouped under its level heading.
    separator = "=" * 100
    for round_name, levels in result_dict.items():
        print(separator)
        print(round_name.upper())
        for level, scores in levels.items():
            print(f"  Level: {level}")
            for metric, val in scores.items():
                print(f"    {metric}: {val:.2f}%")