import os
import argparse
import pandas as pd
from Bio import SeqIO
from pathlib import Path

parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", default="speedppi_fastas", help="Folder with protein1.fasta, protein2.fasta, pair_map.tsv")
parser.add_argument("--output_dir", default="speedppi_batches", help="Folder to store split batches")
parser.add_argument("--batch_size", type=int, default=500, help="Number of pairs per batch")
args = parser.parse_args()

# === Load inputs
p1_records = list(SeqIO.parse(os.path.join(args.input_dir, "protein1.fasta"), "fasta"))
p2_records = list(SeqIO.parse(os.path.join(args.input_dir, "protein2.fasta"), "fasta"))
pair_map = pd.read_csv(os.path.join(args.input_dir, "pair_map.tsv"), sep="\t")

assert len(p1_records) == len(p2_records) == len(pair_map), "Mismatch in input file lengths"

# === Split into batches
os.makedirs(args.output_dir, exist_ok=True)
num_batches = (len(pair_map) + args.batch_size - 1) // args.batch_size

for i in range(num_batches):
    batch_name = f"batch_{i+1:02d}"
    batch_dir = os.path.join(args.output_dir, batch_name)
    os.makedirs(batch_dir, exist_ok=True)

    start = i * args.batch_size
    end = min((i + 1) * args.batch_size, len(pair_map))

    # Write protein1.fasta
    SeqIO.write(p1_records[start:end], os.path.join(batch_dir, "protein1.fasta"), "fasta")
    # Write protein2.fasta
    SeqIO.write(p2_records[start:end], os.path.join(batch_dir, "protein2.fasta"), "fasta")
    # Write pair_map.tsv
    pair_map.iloc[start:end].to_csv(os.path.join(batch_dir, "pair_map.tsv"), sep="\t", index=False)

print(f"Created {num_batches} batches in: {args.output_dir}")