import pandas as pd
from Bio import SeqIO
import argparse
import os

# --- Command-line interface -------------------------------------------------
# --pairs and --sequences are mandatory; --out_dir falls back to a default
# relative directory.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--pairs",
    required=True,
    help="Path to topk_pairs.tsv",
)
parser.add_argument(
    "--sequences",
    required=True,
    help="FASTA file with protein sequences (used for both p1 and p2)",
)
parser.add_argument(
    "--out_dir",
    default="speedppi_batch_input/",
    help="Directory to save output FASTA and map",
)
args = parser.parse_args()

# Create the output directory up front; an already-existing directory is reused.
os.makedirs(args.out_dir, exist_ok=True)

# === Load sequence dictionary
# Maps each FASTA record id to its sequence string. If the same id appears
# more than once, the later record overwrites the earlier one (dict semantics).
# NOTE(review): Biopython's FASTA parser takes the id from the header up to the
# first whitespace, so headers must start with the exact IDs used in --pairs.
seq_dict = {record.id: str(record.seq) for record in SeqIO.parse(args.sequences, "fasta")}

# === Load pairs
# The pairs file is read without a header row; only its first two columns are
# used. Selecting them explicitly prevents extra columns (e.g. a score) from
# shifting the ID columns. Without usecols, pandas would treat the leading
# column as the index when there are more columns than names.
# IDs are kept as strings: FASTA ids are always str, so numeric-looking IDs
# parsed as ints would never match. keep_default_na=False stops tokens such
# as "NA" or "null" from being turned into NaN.
# NOTE(review): a header line in the file, if present, would be treated as a
# pair and later counted as missing — confirm the file is headerless.
pairs_df = pd.read_csv(
    args.pairs,
    sep="\t",
    header=None,
    names=["protein1", "protein2"],
    usecols=[0, 1],
    dtype=str,
    keep_default_na=False,
)

# === Prepare outputs
# protein1.fasta, protein2.fasta and pair_map are filled in lockstep:
# record i of each FASTA and entry i of pair_map describe the same pair.
pair_map = []
missing = 0  # pairs skipped because at least one ID has no sequence

p1_fasta_path = os.path.join(args.out_dir, "protein1.fasta")
p2_fasta_path = os.path.join(args.out_dir, "protein2.fasta")

# `with` guarantees both files are flushed and closed even if a write fails.
with open(p1_fasta_path, "w", encoding="utf-8") as fasta_p1, \
        open(p2_fasta_path, "w", encoding="utf-8") as fasta_p2:
    # Iterating the two columns directly avoids iterrows(), which builds a
    # Series per row (slow, and it can upcast mixed dtypes).
    for p1, p2 in zip(pairs_df["protein1"], pairs_df["protein2"]):
        if p1 not in seq_dict or p2 not in seq_dict:
            missing += 1
            continue

        # Write sequences using the IDs exactly as given in the pairs file
        # (STRING IDs in this pipeline).
        fasta_p1.write(f">{p1}\n{seq_dict[p1]}\n")
        fasta_p2.write(f">{p2}\n{seq_dict[p2]}\n")
        pair_map.append((p1, p2))

# Save pair map: one row per written pair, in the same order as the records
# in protein1.fasta / protein2.fasta.
pair_map_path = os.path.join(args.out_dir, "pair_map.tsv")
pair_map_df = pd.DataFrame(pair_map, columns=["protein1_id", "protein2_id"])
pair_map_df.to_csv(pair_map_path, sep="\t", index=False)

# Summary for the user; the skip line is printed only when something was skipped.
n_saved = len(pair_map)
print(f"Saved {n_saved} valid pairs to {args.out_dir}")
if missing:
    print(f"Skipped {missing} pairs due to missing sequences")