import pandas as pd, os, argparse
from config import *

# Command-line interface. Every option carries a default, so the script can be
# run without arguments; the path defaults are built from names pulled in by
# the star-import from config.
ap = argparse.ArgumentParser()
for _flag, _options in (
    ("--rediscovery_file", {"default": os.path.join(OUTPUT_DIR, "rediscovery_detailed_ranks.tsv")}),
    ("--results_dir", {"default": RE_RANK_OUTPUT_DIR}),
    ("--alias_file", {"default": ALIAS_V11}),
    ("--out", {"default": "final_results/topK_expansion.tsv"}),
    ("--K", {"type": int, "default": 10}),
):
    # Registration order matches the table order above.
    ap.add_argument(_flag, **_options)
args = ap.parse_args()

# Lookup table built from the alias TSV: protein_external_id -> preferred_name.
alias_df = pd.read_csv(args.alias_file, sep="\t")
id_to_name = {
    external_id: preferred
    for external_id, preferred in zip(
        alias_df["protein_external_id"], alias_df["preferred_name"]
    )
}

# Keep only the rediscovery rows whose best_rank is at most K.
rediscovery_df = pd.read_csv(args.rediscovery_file, sep="\t")
_within_top_k = rediscovery_df["best_rank"].le(args.K)
filtered = rediscovery_df[_within_top_k]

# Expand every unique (protein1, best_known_partner) pair found in `filtered`
# into the K highest-scoring entries of
# <results_dir>/<protein1>/top_similars.tsv for that known partner.
processed_pairs = set()  # pairs already attempted, whether or not they succeeded
rows = []  # one DataFrame of top-K neighbours per successfully expanded pair

# Only these two columns are needed to drive the loop, so iterate them directly
# instead of materialising a full Series per row with iterrows().
for p1, kp in zip(filtered["protein1"], filtered["best_known_partner"]):
    pair = (p1, kp)
    if pair in processed_pairs:
        continue
    # Mark the pair before attempting it so that a missing or unreadable file
    # is reported once, not once per duplicate row of the same pair.
    processed_pairs.add(pair)

    path = os.path.join(args.results_dir, p1, "top_similars.tsv")
    if not os.path.exists(path):
        print(f"[WARN] Missing file: {path}")
        continue
    try:
        df_sim = pd.read_csv(path, sep="\t")
        # .copy() detaches the result from df_sim so the column assignments
        # below act on an independent frame (no SettingWithCopyWarning).
        df_kp = (
            df_sim[df_sim["known_partner"] == kp]
            .sort_values("similarity_score", ascending=False)
            .head(args.K)
            .copy()
        )
        df_kp["protein1"] = p1
        df_kp["best_known_partner"] = kp
        # IDs absent from the alias table map to NaN here.
        df_kp["similar_protein_name"] = df_kp["similar_protein"].map(id_to_name)

        # All proteins rediscovered for this pair within the top-K filter.
        # NaNs are dropped because Series.isin treats NaN as equal to NaN,
        # which would otherwise flag every unmapped (NaN) name as rediscovered.
        same_pair = (filtered["protein1"] == p1) & (filtered["best_known_partner"] == kp)
        rediscovered_set = set(filtered.loc[same_pair, "rediscovered_protein"].dropna())
        # NOTE(review): this compares preferred names against
        # "rediscovered_protein"; presumably that column also holds preferred
        # names rather than external IDs — confirm against the upstream file.
        df_kp["rediscovered_flag"] = df_kp["similar_protein_name"].isin(rediscovered_set)
        rows.append(df_kp)
    except Exception as e:
        # Deliberately best-effort: one malformed per-protein file (bad TSV,
        # missing column, ...) is reported and skipped, not fatal to the run.
        print(f"ERROR {path}: {e}")

# Concatenate the per-pair tables, keep the report columns in a fixed order,
# and write the result as a TSV to args.out.
if rows:
    out = pd.concat(rows, ignore_index=True)[[
        "protein1","best_known_partner","known_partner",
        "similar_protein","similar_protein_name","similarity_score","rediscovered_flag"
    ]]
    # The default --out lives under "final_results/"; to_csv does not create
    # missing parent directories, so make sure the folder exists first.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    out.to_csv(args.out, sep="\t", index=False)
    print(f"Saved: {args.out} (K={args.K}) | rows={len(out)}")
else:
    # Nothing was expanded (no rows passed the filter, or every file was
    # missing or unreadable), so no output file is written.
    print("No valid rows.")