import pandas as pd
import os
from config import *

def extract_known_partners():
    """Write, per candidate protein, its v11 partners and its new v12 targets.

    Inputs (all paths/limits come from ``config``):
      * ``{OUTPUT_DIR}/v11_named.tsv`` - v11 interactions; must provide the
        columns ``protein1_name`` and ``protein2_name``.
      * ``OUTPUT_COMBINED`` - header-less TSV read as
        ``protein1``, ``protein2``, ``label`` (the v12 predictions).
      * ``ALIAS_V12`` - TSV with ``#string_protein_id`` and
        ``preferred_name`` columns, used to translate v12 IDs to names.
      * ``TOP_PROTEINS`` - TSV of candidates with a ``protein1_name``
        column; only the first ``N_CANDIDATES`` rows are processed.

    For every candidate a directory ``{OUTPUT_DIR}/{protein1_name}`` is
    created (if needed) and four TSV files are written into it:
      * ``v11_partners.tsv`` - sorted names that share a v11 row with the
        candidate (either column).
      * ``v12_targets.tsv`` - v12 rows whose ``protein1`` is the candidate
        and whose ``protein2`` is not one of those v11 partners, plus a
        boolean ``p2_in_v11`` column.
      * ``v12_targets_known_p2.tsv`` / ``v12_targets_novel_p2.tsv`` - the
        previous table split on ``p2_in_v11``.

    Progress and a preview of each target table are printed to stdout.
    Returns ``None``.
    """
    print("Loading data...")

    # v11 interactions, already annotated with protein names.
    v11 = pd.read_csv(f"{OUTPUT_DIR}/v11_named.tsv", sep="\t")

    # v12 predictions; the file has no header row.
    v12 = pd.read_csv(
        OUTPUT_COMBINED,
        sep="\t",
        header=None,
        names=["protein1", "protein2", "label"],
    )

    # Translate v12 protein IDs to preferred names. IDs missing from the
    # alias table map to NaN.
    aliases_v12 = pd.read_csv(ALIAS_V12, sep="\t")
    id_to_name = dict(
        zip(aliases_v12["#string_protein_id"], aliases_v12["preferred_name"])
    )
    v12["protein1_name"] = v12["protein1"].map(id_to_name)
    v12["protein2_name"] = v12["protein2"].map(id_to_name)

    # Candidate list, capped at N_CANDIDATES rows. The file may hold fewer
    # rows than the cap, so progress is reported against the real count.
    top_proteins = pd.read_csv(TOP_PROTEINS, sep="\t").head(N_CANDIDATES)
    n_candidates = len(top_proteins)

    # Every protein name that occurs on either side of a v11 interaction.
    all_v11_proteins = set(v11["protein1_name"]) | set(v11["protein2_name"])

    for position, p1_name in enumerate(top_proteins["protein1_name"], start=1):
        print(f"[{position}/{n_candidates}] Processing: {p1_name}")

        # Known partners: the other names on v11 rows that mention the
        # candidate in either column.
        known = v11[(v11["protein1_name"] == p1_name) | (v11["protein2_name"] == p1_name)]
        partners = set(known["protein1_name"]) | set(known["protein2_name"])
        partners.discard(p1_name)

        # v12 rows where the candidate is protein1 and protein2 is not an
        # already-known partner. Only the protein1 side of v12 is searched.
        # NOTE: the score filter (label >= IS_THRESHOLD) is currently
        # disabled, so every prediction for the candidate is kept.
        is_candidate = v12["protein1_name"] == p1_name
        is_new_partner = ~v12["protein2_name"].isin(partners)
        # Explicit copy: a column is added below, and assigning onto a
        # filtered view of v12 is unreliable (SettingWithCopyWarning).
        novel_targets = v12[is_candidate & is_new_partner].copy()

        # Flag targets that already occur somewhere in v11. isin() always
        # yields a boolean mask, even when there are no rows, which keeps
        # the split below a row filter.
        novel_targets["p2_in_v11"] = novel_targets["protein2_name"].isin(all_v11_proteins)

        # Save output
        out_dir = f"{OUTPUT_DIR}/{p1_name}"
        os.makedirs(out_dir, exist_ok=True)

        pd.DataFrame({"known_partner": sorted(partners)}).to_csv(
            f"{out_dir}/v11_partners.tsv", sep="\t", index=False
        )
        novel_targets.to_csv(f"{out_dir}/v12_targets.tsv", sep="\t", index=False)
        print(novel_targets.head())

        # Split by whether protein2 was already present in v11.
        p2_known = novel_targets["p2_in_v11"]
        novel_targets[p2_known].to_csv(
            f"{out_dir}/v12_targets_known_p2.tsv", sep="\t", index=False
        )
        novel_targets[~p2_known].to_csv(
            f"{out_dir}/v12_targets_novel_p2.tsv", sep="\t", index=False
        )

    print("Extraction complete.")