import os

import pandas as pd
from Bio import SeqIO

from config import *

# === STEP 1: Load the PPI table and collect every protein ID it mentions ===
# Only the first two tab-separated columns are read. The file's own header row
# is consumed (header=0) and replaced by the column names given here.
df = pd.read_csv(
    ppi_file,
    sep="\t",
    header=0,
    usecols=[0, 1],
    names=["protein1", "protein2"],
)
# A protein counts once regardless of which side of an interaction it is on.
protein_ids = set(df["protein1"]) | set(df["protein2"])
print(f"Total unique proteins in PPI: {len(protein_ids)}")

# === STEP 2: Copy v11 sequences whose ID appears in the PPI ===
# Each matching record is written to `filtered_v11` as soon as it is read, and
# its ID is remembered so the next step can tell which proteins are still missing.
saved_ids_v11 = set()
with open(filtered_v11, "w") as v11_out:
    for entry in SeqIO.parse(v11_fasta, "fasta"):
        if entry.id not in protein_ids:
            continue  # not part of the PPI network
        SeqIO.write(entry, v11_out, "fasta")
        saved_ids_v11.add(entry.id)
print(f"Sequences found in v11: {len(saved_ids_v11)}")

# === STEP 3: Work out which PPI proteins v11 did not supply ===
# Set difference: every PPI protein ID that was not written in step 2.
missing_ids = protein_ids.difference(saved_ids_v11)
print(f"Proteins missing from v11: {len(missing_ids)}")

# === STEP 4: Recover from v12 only the sequences v11 lacked ===
# Records whose ID is not in `missing_ids` are skipped, so nothing already
# taken from v11 is written again here.
saved_ids_v12 = set()
with open(filtered_v12, "w") as v12_out:
    for entry in SeqIO.parse(v12_fasta, "fasta"):
        if entry.id not in missing_ids:
            continue
        SeqIO.write(entry, v12_out, "fasta")
        saved_ids_v12.add(entry.id)
        print(f"Found {entry.id} in v12")
print(f"Sequences recovered from v12: {len(saved_ids_v12)}")

# === STEP 5: Merge the two filtered FASTA files into a single output ===
# v11 records come first, followed by the v12 records recovered in step 4.
combined_records = list(SeqIO.parse(filtered_v11, "fasta"))
combined_records.extend(SeqIO.parse(filtered_v12, "fasta"))
SeqIO.write(combined_records, combined_fasta, "fasta")
print(f"Final combined FASTA saved: {combined_fasta}")
print(f"Total sequences in combined FASTA: {len(combined_records)}")

# === STEP 6: Restrict PPI to proteins present in combined FASTA ===
# The IDs are re-read from the combined FASTA on disk, so the filter reflects
# exactly what was written in step 5.
valid_ids = {rec.id for rec in SeqIO.parse(combined_fasta, "fasta")}

# Keep an interaction only when BOTH partners have a sequence; a pair with one
# unresolved protein is dropped.
both_present = df["protein1"].isin(valid_ids) & df["protein2"].isin(valid_ids)
df_filtered = df[both_present]

print(f"Interactions kept after combined FASTA filter: {len(df_filtered)} "
      f"(removed {len(df) - len(df_filtered)})")

# Save the filtered PPI next to the combined FASTA (same directory).
# This needs the standard-library `os` module to be imported by the file itself;
# do not rely on `from config import *` re-exporting it.
ppi_filtered_out = os.path.join(os.path.dirname(combined_fasta), "human_complete_filtered.tsv")
df_filtered.to_csv(ppi_filtered_out, sep="\t", index=False)
print(f"Filtered PPI saved: {ppi_filtered_out}")