import os
import pandas as pd
from config_base import *

def _names_of(protein_ids, id_to_name):
    """Translate protein IDs to preferred names, dropping IDs with no alias entry."""
    return {id_to_name[pid] for pid in protein_ids if pid in id_to_name}


def _count_in_universe(names, name_to_id, universe):
    """Count how many *names* map (via *name_to_id*) to an ID contained in *universe*.

    A name with no ID in *name_to_id* (or a falsy ID) counts as "not in universe".

    Returns:
        A tuple ``(in_count, not_in_count)`` whose sum is ``len(names)``.
    """
    in_count = 0
    for name in names:
        pid = name_to_id.get(name)
        if pid and pid in universe:
            in_count += 1
    return in_count, len(names) - in_count


def evaluate_rediscovery():
    """Summarise, per protein, how many v12 targets its similarity list rediscovers.

    For every sub-directory ``OUTPUT_DIR/<protein1>`` that contains
    ``v12_targets.tsv``, the ``protein2_name`` values in that file are compared
    with the ``similar_protein`` IDs in ``top_similars.tsv``. Those IDs are
    translated to preferred names through the alias table at ``ALIAS_V11``.
    Each evaluated protein yields one summary row with:

    * the number of distinct targets and how many were rediscovered,
    * how many targets map to an ID appearing in ``OUTPUT_DIR/v11_named.tsv``
      (``p2_in_v11``) and how many do not (``p2_not_in_v11``),
    * the sorted, comma-separated rediscovered names.

    Proteins that have ``v12_targets.tsv`` but no ``top_similars.tsv`` are
    collected separately into ``OUTPUT_DIR/no_known_partners.tsv``. That file
    is written only when at least one such protein exists.

    Reads the module-level ``ALIAS_V11``, ``OUTPUT_DIR`` and ``SUMMARY_FILE``
    (provided by ``config_base``). ``SUMMARY_FILE`` is always written with a
    header row, even when no protein was evaluated. Returns None.
    """
    print("Evaluating retrieval performance...")

    # Explicit column order so an empty result still produces a readable header.
    columns = [
        "protein1",
        "# v12_targets",
        "# rediscovered",
        "p2_in_v11",
        "p2_not_in_v11",
        "rediscovered_proteins",
    ]

    # Load aliases.
    # NOTE(review): if a preferred_name appears for several IDs, the last row
    # wins in name_to_id — confirm names are unique in ALIAS_V11.
    alias_df = pd.read_csv(ALIAS_V11, sep="\t")
    id_to_name = dict(zip(alias_df["protein_external_id"], alias_df["preferred_name"]))
    name_to_id = dict(zip(alias_df["preferred_name"], alias_df["protein_external_id"]))

    # Load v11 protein universe.
    # NOTE(review): this set is compared against IDs from name_to_id; if
    # v11_named.tsv stores preferred names instead of IDs, every target is
    # counted as not-in-v11 — confirm the file's format.
    v11_named = pd.read_csv(os.path.join(OUTPUT_DIR, "v11_named.tsv"), sep="\t")
    all_v11_proteins = set(v11_named["protein1"]).union(set(v11_named["protein2"]))

    rediscovery_stats = []
    skipped_proteins = []

    # Sorted so the summary row order is reproducible (os.listdir order is arbitrary).
    for p1_name in sorted(os.listdir(OUTPUT_DIR)):
        p1_dir = os.path.join(OUTPUT_DIR, p1_name)
        if not os.path.isdir(p1_dir):
            continue

        top_sim_file = os.path.join(p1_dir, "top_similars.tsv")
        v12_target_file = os.path.join(p1_dir, "v12_targets.tsv")

        if not os.path.exists(v12_target_file):
            continue

        targets_df = pd.read_csv(v12_target_file, sep="\t")
        novel_p2 = set(targets_df["protein2_name"].dropna().unique())

        if not os.path.exists(top_sim_file):
            skipped_proteins.append({
                "protein1": p1_name,
                "# v12_targets": len(novel_p2),
                "# rediscovered": 0,
                "p2_in_v11": None,
                "p2_not_in_v11": None,
                "rediscovered_proteins": "",
            })
            continue

        similars_df = pd.read_csv(top_sim_file, sep="\t")
        predicted_similars = set(similars_df["similar_protein"].dropna().unique())

        # IDs without an alias entry cannot be compared by name and are dropped.
        rediscovered = novel_p2.intersection(_names_of(predicted_similars, id_to_name))

        # For each p1, how many novel p2 are already in v11 (involved in other
        # interactions) and how many are not.
        p2_in_v11, p2_not_in_v11 = _count_in_universe(novel_p2, name_to_id, all_v11_proteins)

        current_stats = {
            "protein1": p1_name,
            "# v12_targets": len(novel_p2),
            "# rediscovered": len(rediscovered),
            "p2_in_v11": p2_in_v11,
            "p2_not_in_v11": p2_not_in_v11,
            "rediscovered_proteins": ", ".join(sorted(rediscovered)),
        }
        print(f"Evaluating {p1_name}: {current_stats}")
        rediscovery_stats.append(current_stats)

    # Save results.
    out_df = pd.DataFrame(rediscovery_stats, columns=columns)
    out_df.to_csv(SUMMARY_FILE, sep="\t", index=False)
    print(f"Saved summary: {SUMMARY_FILE}")

    if skipped_proteins:
        skipped_df = pd.DataFrame(skipped_proteins, columns=columns)
        skipped_path = os.path.join(OUTPUT_DIR, "no_known_partners.tsv")
        skipped_df.to_csv(skipped_path, sep="\t", index=False)
        print(f"Saved proteins without known partners to: {skipped_path}")

if __name__ == "__main__":
    # Script entry point: run the evaluation only when executed directly, not on import.
    evaluate_rediscovery()
