#!/usr/bin/env python3
#
# Requirements:
#   * Python ≥3.8
#   * pandas
#   * MMseqs2 (in $PATH)  ≥15.6f452
#
# Usage:
#   python compute_test_vs_train_identity.py train.csv test.csv
#
# --------------------------------------------------------------------

import pandas as pd
import pathlib, subprocess, tempfile, shutil, statistics, sys

# Both input CSVs are required positional arguments. Exit with a usage
# message (non-zero status) instead of an IndexError traceback when either
# one is missing; any extra arguments are ignored, as before.
if len(sys.argv) < 3:
    sys.exit(f"Usage: python {sys.argv[0]} train.csv test.csv")

TRAIN_CSV = pathlib.Path(sys.argv[1])  # training-set CSV, e.g. train.csv
TEST_CSV = pathlib.Path(sys.argv[2])  # test-set CSV, e.g. test.csv


def csv_to_fasta(csv_file: pathlib.Path, fasta_file: pathlib.Path) -> None:
    """Convert a CSV with ``name`` and ``seqres`` columns into a FASTA file.

    Every usable row becomes one two-line record: a ``>name`` header line
    followed by the sequence on a single line.

    All whitespace inside the sequence (spaces, tabs, CR/LF) is removed.
    Rows whose ``seqres`` is missing, or empty once whitespace is removed,
    are skipped and counted in a warning on stderr; writing them would
    otherwise produce the bogus sequence ``nan`` or an empty record.

    Args:
        csv_file: Path of the input CSV; must contain ``name`` and
            ``seqres`` columns.
        fasta_file: Path of the FASTA file to create (overwritten).

    Raises:
        KeyError: If either required column is absent from the CSV.
    """
    df = pd.read_csv(csv_file)
    skipped = 0
    with fasta_file.open("w") as fh:
        # Iterate the two needed columns directly rather than building a
        # Series per row with iterrows().
        for name, raw_seq in zip(df["name"], df["seqres"]):
            # str.split() with no argument splits on any whitespace run, so
            # joining the pieces strips tabs and carriage returns as well.
            seq = "" if pd.isna(raw_seq) else "".join(str(raw_seq).split())
            if not seq:
                skipped += 1
                continue
            fh.write(f">{name}\n{seq}\n")
    if skipped:
        print(
            f"warning: skipped {skipped} record(s) without a sequence "
            f"in {csv_file}",
            file=sys.stderr,
        )


# Analysis parameters. The defaults reproduce the original report exactly;
# the report strings below are derived from them so they cannot drift apart.
MMSEQS_SENSITIVITY = "8.0"  # value passed to `mmseqs easy-search -s`
MIN_ALN_LEN = 20  # minimum `alnlen` for a hit to be considered
IDENTITY_THRESHOLD = 40.0  # best-hit %ID at or above which a protein is flagged

# Fail early with a readable message rather than a FileNotFoundError
# traceback from subprocess after the FASTA files have been written.
if shutil.which("mmseqs") is None:
    sys.exit("error: 'mmseqs' executable not found in $PATH")

with tempfile.TemporaryDirectory() as tmpdir:
    tmp = pathlib.Path(tmpdir)
    train_fa = tmp / "train.fa"
    test_fa = tmp / "test.fa"

    # 1. CSV → FASTA
    csv_to_fasta(TRAIN_CSV, train_fa)
    csv_to_fasta(TEST_CSV, test_fa)

    # 2. MMseqs2: search every test sequence against the training set.
    #    All reported hits are written; the best hit per query is selected
    #    in step 3.
    result_m8 = tmp / "result.m8"
    mm_tmp = tmp / "mm_tmp"

    for name, fa in (("train", train_fa), ("test", test_fa)):
        with fa.open() as fh:  # close the handle deterministically
            count = sum(1 for line in fh if line.startswith(">"))
        print(f"{name}.fa contains {count} sequences")

    # Filtering options such as --min-seq-id, --cov-mode, -c, --max-seqs and
    # -e are deliberately left at the MMseqs2 defaults.
    cmd = [
        "mmseqs",
        "easy-search",
        str(test_fa),
        str(train_fa),
        str(result_m8),
        str(mm_tmp),
        "-s",
        MMSEQS_SENSITIVITY,  # high sensitivity
        "--format-output",
        "query,target,pident,alnlen",
    ]
    subprocess.run(cmd, check=True)

    # The result table has no header row (it is parsed with explicit column
    # names below), so every line is one hit.
    if result_m8.exists():
        with result_m8.open() as fh:
            n_hit_lines = sum(1 for _ in fh)
    else:
        n_hit_lines = 0
    print(
        train_fa,
        "vs",
        test_fa,
        "→",
        result_m8,
        "contains",
        n_hit_lines,
        "lines",
    )

    # 3. Parse results, keep best hit per query with alnLen ≥ MIN_ALN_LEN
    cols = ["query", "target", "pident", "alnlen"]
    try:
        res = pd.read_csv(result_m8, sep="\t", names=cols)
    except (pd.errors.EmptyDataError, FileNotFoundError):
        # No hits at all: treat as an empty table instead of crashing.
        res = pd.DataFrame(columns=cols)
    print(res)

    test_names = pd.read_csv(TEST_CSV)["name"]
    if res.empty:
        pid_values = []
    else:
        res = res[res["alnlen"] >= MIN_ALN_LEN]
        # Highest %ID first, then keep the first row seen for each query.
        best = res.sort_values("pident", ascending=False).drop_duplicates(
            "query"
        )

        # Any test protein without a qualifying hit gets NaN
        best = test_names.to_frame().merge(
            best[["query", "pident"]],
            left_on="name",
            right_on="query",
            how="left",
        )
        pid_values = best["pident"].dropna().tolist()

    if pid_values:
        mean_pid = statistics.mean(pid_values)
        median_pid = statistics.median(pid_values)
        # stdev() needs at least two data points.
        std_pid = statistics.stdev(pid_values) if len(pid_values) > 1 else 0
        over_40 = sum(p >= IDENTITY_THRESHOLD for p in pid_values)
    else:
        mean_pid = median_pid = std_pid = over_40 = 0

    # Percentage is relative to ALL test proteins, not only those with a
    # hit; guard against an empty test set.
    n_test = len(test_names)
    pct_over_40 = over_40 / n_test * 100 if n_test else 0.0

    # 4. Report
    print(
        "--- Test ↔ Train sequence identity summary "
        f"(MMseqs2, s={MMSEQS_SENSITIVITY}) ---"
    )
    print(f"Test proteins analysed     : {n_test}")
    print(f"With ≥{MIN_ALN_LEN}-residue alignment : {len(pid_values)}")
    print(f"Mean of best-hit %ID       : {mean_pid:6.2f}")
    print(f"Median                     : {median_pid:6.2f}")
    print(f"Std-dev                    : {std_pid:6.2f}")
    print(
        f"≥ {IDENTITY_THRESHOLD:g} % identity hits       : {over_40} "
        f"({pct_over_40:5.1f} %)"
    )
