import pandas as pd
import os
from config import *
import argparse

def build_arg_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--topk_file", default="final_results/top10_expansion.tsv")
    parser.add_argument("--score_file", required=True, help="TSV file with (protein1_id, protein2_id, score)")
    parser.add_argument("--score_name", required=True, help="Name of the score column to add, e.g., cosine")
    parser.add_argument("--output_file", default=None, help="Output file (optional)")
    # ALIAS_V11 is provided by the project's `config` module (star-imported).
    parser.add_argument("--alias_file", default=ALIAS_V11)
    return parser


def load_name_to_id(alias_file):
    """Read the alias TSV and map ``preferred_name`` -> ``protein_external_id``.

    If a preferred_name occurs on several rows, the last occurrence wins
    (plain ``dict`` construction semantics).
    """
    alias_df = pd.read_csv(alias_file, sep="\t")
    return dict(zip(alias_df["preferred_name"], alias_df["protein_external_id"]))


def load_scores(score_file, score_name):
    """Read a headerless TSV of (protein1_id, protein2_id, <score_name>) rows."""
    return pd.read_csv(
        score_file,
        sep="\t",
        header=None,
        names=["protein1_id", "protein2_id", score_name],
    )


def merge_scores(df_top, df_score, name_to_id):
    """Attach the external score to each (protein1, similar_protein) row.

    ``protein1`` names are translated to IDs via *name_to_id* and stored in a
    new ``protein1_id`` column; names with no mapping get NaN and therefore a
    missing score. The input *df_top* is not modified.

    Returns:
        A new DataFrame with every row of *df_top* (left join, order kept) plus
        the score column; the helper ``protein2_id`` key column is dropped.
    """
    top = df_top.copy()
    top["protein1_id"] = top["protein1"].map(name_to_id)
    # NOTE(review): the match is directional -- a score stored only as
    # (similar_protein, protein1_id) will not be found. Duplicate
    # (protein1_id, protein2_id) rows in the score file would also duplicate
    # rows here; confirm the score file is one row per ordered pair.
    merged = top.merge(
        df_score,
        how="left",
        left_on=["protein1_id", "similar_protein"],
        right_on=["protein1_id", "protein2_id"],
    )
    # protein1_id is shared by both key lists and is coalesced by pandas;
    # protein2_id duplicates similar_protein, so drop it.
    return merged.drop(columns=["protein2_id"])


def add_rank_columns(merged, score_name):
    """Add per-(protein1, known_partner) dense ranks and their difference.

    Both the new score and ``similarity_score`` are ranked descending (rank 1
    = highest). ``delta_rank_vs_similarity`` = similarity rank - new-score
    rank, so a positive delta means the new score places the row higher.
    Missing scores yield NaN ranks and NaN deltas. Returns a new DataFrame.
    """
    keys = ["protein1", "known_partner"]
    score_rank = merged.groupby(keys)[score_name].rank(method="dense", ascending=False)
    sim_rank = merged.groupby(keys)["similarity_score"].rank(method="dense", ascending=False)
    ranked = merged.copy()
    ranked[f"{score_name}_rank"] = score_rank
    ranked["similarity_score_rank"] = sim_rank
    ranked["delta_rank_vs_similarity"] = sim_rank - score_rank
    return ranked


def reorder_columns(merged, score_name):
    """Put the known columns first (in a fixed order); keep the rest after them."""
    column_order = [
        "protein1", "protein1_id", "best_known_partner", "known_partner",
        "similar_protein", "similar_protein_name",
        "similarity_score", score_name, "rediscovered_flag",
        "similarity_score_rank", f"{score_name}_rank", "delta_rank_vs_similarity",
    ]
    leading = [col for col in column_order if col in merged.columns]
    trailing = [col for col in merged.columns if col not in column_order]
    return merged[leading + trailing]


def summarize_rediscovery(merged):
    """Count rank shifts among rows whose ``rediscovered_flag`` equals True.

    Returns a dict with ``improved`` (delta > 0), ``worsened`` (delta < 0),
    ``unchanged`` (delta == 0), ``missing`` (NaN delta, e.g. no score found)
    and ``total`` (all rediscovered rows), so the first four sum to total.
    """
    rediscovered = merged[merged["rediscovered_flag"].eq(True)]
    delta = rediscovered["delta_rank_vs_similarity"]
    return {
        "improved": int((delta > 0).sum()),
        "worsened": int((delta < 0).sum()),
        "unchanged": int((delta == 0).sum()),
        "missing": int(delta.isna().sum()),
        "total": len(rediscovered),
    }


def main(argv=None):
    """Re-rank a top-k expansion table by an external pairwise score.

    Loads the alias mapping, the top-k table and the score file, merges the
    score in, ranks by both scores, prints a rank-shift summary for
    rediscovered rows, and writes the result as TSV.

    Args:
        argv: Argument list for argparse; ``None`` means ``sys.argv[1:]``.
    """
    args = build_arg_parser().parse_args(argv)

    name_to_id = load_name_to_id(args.alias_file)
    df_top = pd.read_csv(args.topk_file, sep="\t")
    df_score = load_scores(args.score_file, args.score_name)

    merged = merge_scores(df_top, df_score, name_to_id)
    n_unmapped = int(merged["protein1_id"].isna().sum())
    if n_unmapped:
        print(f"Warning: {n_unmapped} rows have a protein1 name not found in the alias file; "
              f"their {args.score_name} is missing.")

    merged = add_rank_columns(merged, args.score_name)
    merged = reorder_columns(merged, args.score_name)

    if "rediscovered_flag" in merged.columns:
        stats = summarize_rediscovery(merged)
        total = stats["total"]
        print(f"\nRediscovered ranking shift ({args.score_name} vs similarity):")
        print(f"  Improved:  {stats['improved']} / {total}")
        print(f"  Worsened:  {stats['worsened']} / {total}")
        print(f"  Unchanged: {stats['unchanged']} / {total}")
        if stats["missing"]:
            print(f"  Missing:   {stats['missing']} / {total}")
    else:
        print("\nNo 'rediscovered_flag' column; skipping ranking-shift summary.")

    output_file = args.output_file or f"final_results/top10_ranked_by_{args.score_name}.tsv"
    # to_csv does not create parent directories; make sure they exist.
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    merged.to_csv(output_file, sep="\t", index=False)
    print(f"Saved ranked file to: {output_file}")


if __name__ == "__main__":
    main()
