"""
Leakage‑free train/test split for protein sequences by clustering with MMseqs2.

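Usage (every option is optional; the values shown are the defaults, except
--seed, which is unset unless given):

    python split_train_test_by_mmseqs_cluster.py all.csv train.csv test.csv \
        --test-frac 0.10 --min-seq-id 0.30 --cov 0.80 --seed 42

    --test-frac    ≈10 % of the sequences go to the test set
    --min-seq-id   identity threshold used for clustering
    --cov          coverage threshold used for clustering
    --seed         RNG seed; makes the split reproducible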

"""

from __future__ import annotations
import argparse
import random
import shutil
import os
import subprocess
import sys
import tempfile
from pathlib import Path

import pandas as pd


def csv_to_fasta(csv_file: Path, fasta_file: Path) -> None:
    """Convert a CSV with ``name`` and ``seqres`` columns into a FASTA file.

    Every CSV row becomes one two-line record: a ``>name`` header followed by
    the sequence on a single line.  All whitespace inside the sequence
    (spaces, tabs, carriage returns, newlines) is removed so that a wrapped
    or CRLF-terminated sequence cannot break the record.

    Args:
        csv_file: Input CSV; must contain the columns ``name`` and ``seqres``.
        fasta_file: Destination path; overwritten if it already exists.

    Raises:
        KeyError: If either required column is missing from the CSV.
    """
    df = pd.read_csv(csv_file)
    with fasta_file.open("w") as fh:
        # Walk the two needed columns directly; this avoids building a Series
        # per row (as iterrows() does) and keeps each column's own dtype.
        for name, seqres in zip(df["name"], df["seqres"]):
            # NOTE: a missing sequence is read as NaN and is therefore still
            # written as the literal text "nan".
            seq = "".join(str(seqres).split())
            fh.write(f">{name}\n{seq}\n")


def run_mmseqs_cluster(
    fasta_in: Path,
    work_dir: Path,
    min_seq_id: float,
    cov: float,
) -> Path:
    """Cluster a FASTA file with ``mmseqs easy-cluster``.

    Args:
        fasta_in: FASTA file containing every sequence to cluster.
        work_dir: Directory that receives the MMseqs2 output files (prefix
            ``clusters``) and its scratch directory ``mm_tmp``.
        min_seq_id: Minimum sequence identity, passed as ``--min-seq-id``.
        cov: Minimum alignment coverage, passed as ``-c``.

    Returns:
        Path of the ``clusters_cluster.tsv`` file inside ``work_dir``, the
        representative→member table that easy-cluster writes next to the
        output prefix.

    Raises:
        subprocess.CalledProcessError: If MMseqs2 exits with a non-zero status.
        FileNotFoundError: If the ``mmseqs`` executable is not on ``PATH``.
    """
    out_prefix = work_dir / "clusters"
    tmp = work_dir / "mm_tmp"

    cmd = [
        "mmseqs",
        "easy-cluster",
        str(fasta_in),
        str(out_prefix),
        str(tmp),
        # Use the caller's thresholds; they were previously ignored in favour
        # of hard-coded values, which made the CLI options ineffective.
        "--min-seq-id",
        str(min_seq_id),
        "-c",
        str(cov),
        # Coverage mode 1 and sensitivity 7.5 are kept as fixed settings.
        "--cov-mode",
        "1",
        "-s",
        "7.5",
    ]
    subprocess.run(cmd, check=True)

    # Path.with_suffix() rejects suffixes that do not start with ".", so the
    # "_cluster.tsv" tail has to be appended to the file name instead.
    return out_prefix.with_name(out_prefix.name + "_cluster.tsv")


def build_cluster_dict(cluster_tsv: Path) -> dict[str, list[str]]:
    """Load a two-column representative→member table into a dictionary.

    Args:
        cluster_tsv: Tab-separated file without a header row; column 1 is the
            cluster representative, column 2 a member of that cluster.

    Returns:
        ``{representative: [members...]}``.  Keys are in sorted order and each
        member list keeps the order in which the rows appear in the file, so
        the result is deterministic for a given input file.
    """
    table = pd.read_csv(
        cluster_tsv, sep="\t", names=["rep", "member"], header=None
    )
    # sort=True (the pandas default) is spelled out because the seeded
    # shuffle performed by the caller depends on a stable key order.
    grouped = table.groupby("rep", sort=True)["member"]
    return grouped.apply(list).to_dict()


def sample_clusters(
    clusters: dict[str, list[str]],
    target_n: int,
    rng: random.Random,
) -> set[str]:
    """Pick whole clusters at random until at least ``target_n`` IDs are taken.

    The representatives are shuffled once with ``rng`` and then consumed in
    that order.  A cluster is always taken in full, so the number of selected
    IDs may exceed ``target_n``; if the clusters run out first, everything is
    returned.

    Args:
        clusters: Mapping of representative ID to the member IDs of its cluster.
        target_n: Number of member IDs to reach before stopping.
        rng: Random generator used for the shuffle.

    Returns:
        The set of member IDs belonging to the chosen clusters.
    """
    order = list(clusters)
    rng.shuffle(order)

    chosen: set[str] = set()
    still_needed = target_n
    for rep in order:
        # Stop as soon as the quota has been met (or exceeded).
        if still_needed <= 0:
            return chosen
        group = clusters[rep]
        chosen.update(group)
        still_needed -= len(group)
    return chosen


# ------------------------------------------------------------ CLI
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command line of the split script.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        Namespace with the three CSV paths (``all_csv``, ``train_csv``,
        ``test_csv``) and the options ``test_frac``, ``min_seq_id``, ``cov``
        and ``seed``.
    """
    parser = argparse.ArgumentParser(description="Cluster‑based train/test split")

    # Positional CSV paths, in the order they appear on the command line.
    path_arguments = (
        ("all_csv", "CSV with all sequences"),
        ("train_csv", "output: train CSV"),
        ("test_csv", "output: test CSV"),
    )
    for dest, help_text in path_arguments:
        parser.add_argument(dest, type=Path, help=help_text)

    # Float-valued tuning options together with their defaults.
    float_options = (
        ("--test-frac", 0.10, "~fraction for test"),
        ("--min-seq-id", 0.30, "identity threshold"),
        ("--cov", 0.80, "coverage threshold"),
    )
    for flag, default, help_text in float_options:
        parser.add_argument(flag, type=float, default=default, help=help_text)

    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="random‑seed for reproducibility",
    )
    return parser.parse_args(argv)


def main() -> None:
    """Run the pipeline: CSV → FASTA → MMseqs2 clusters → train/test CSVs.

    Whole clusters are assigned to the test set, so no cluster is shared
    between the two output files.  Rows keep their original order in both
    outputs, which makes the result reproducible for a fixed ``--seed``.

    Raises:
        KeyError: If a clustered ID does not occur in the ``name`` column of
            the input CSV.
        subprocess.CalledProcessError: If the MMseqs2 run fails.
    """
    args = parse_args()
    rng = random.Random(args.seed)

    # Intermediate files are only needed until the cluster table is loaded.
    with tempfile.TemporaryDirectory() as td:
        tmp = Path(td)

        # 1  CSV → FASTA
        all_fa = tmp / "all.fa"
        csv_to_fasta(args.all_csv, all_fa)

        # 2  clustering
        cluster_tsv = run_mmseqs_cluster(all_fa, tmp, args.min_seq_id, args.cov)

        # Read the table into memory before the temporary directory is removed.
        clusters = build_cluster_dict(cluster_tsv)

    # 3  sample clusters
    n_total = sum(len(members) for members in clusters.values())
    target_test = int(n_total * args.test_frac)
    test_ids = sample_clusters(clusters, target_test, rng)

    # 4  write CSVs
    df_all = pd.read_csv(args.all_csv).set_index("name")

    # An ID that cannot be matched would silently end up in the training set,
    # so fail loudly instead.
    unknown = set(test_ids).difference(df_all.index)
    if unknown:
        examples = sorted(map(str, unknown))[:5]
        raise KeyError(
            f"{len(unknown)} clustered ID(s) not found in column 'name', "
            f"e.g. {examples}"
        )

    # A boolean mask is used rather than .loc[set]: pandas no longer accepts
    # a set as an indexer, and set iteration order is not stable between runs.
    in_test = df_all.index.isin(test_ids)
    test_df = df_all[in_test]
    train_df = df_all[~in_test]

    for out_csv in (args.train_csv, args.test_csv):
        out_csv.parent.mkdir(parents=True, exist_ok=True)
    train_df.to_csv(args.train_csv, index=True)
    test_df.to_csv(args.test_csv, index=True)

    # 5  summary
    test_share = len(test_df) / len(df_all) if len(df_all) else 0.0
    print(
        f"Wrote {len(train_df):,} sequences to {args.train_csv} "
        f"and {len(test_df):,} to {args.test_csv} "
        f"({test_share:.1%} in test)"
    )


if __name__ == "__main__":
    # Script entry point: a failed MMseqs2 invocation is reported as a short
    # error message (non-zero exit status) instead of a traceback.
    try:
        main()
    except subprocess.CalledProcessError as err:
        failure = f"ERROR: MMseqs2 command failed ({err})"
        sys.exit(failure)
