#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Preprocess GRAMPA (Witten & Witten, 2019) and train one Random Forest classifier per pathogen.
"""

import argparse, json, re, sys, math
from pathlib import Path
from typing import List, Dict, Optional
import numpy as np
import pandas as pd

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score
import joblib

def read_fasta(fp: Path):
    """Stream ``(header, sequence)`` records out of a FASTA file.

    Blank lines are skipped. Sequence lines that appear before the first
    ``>`` header are discarded. Each header is returned without its leading
    ``>`` and without surrounding whitespace. Each sequence is the
    concatenation of its lines with every non-letter character removed,
    upper-cased. Undecodable bytes are ignored when reading as UTF-8.

    Args:
        fp: Path of the FASTA file to read.

    Yields:
        ``(header, sequence)`` tuples in file order.
    """
    def _clean(pieces):
        # Keep ASCII letters only, then normalise to upper case.
        return re.sub(r"[^A-Za-z]", "", "".join(pieces)).upper()

    name = None
    pieces = []
    with fp.open("r", encoding="utf-8", errors="ignore") as handle:
        for raw in handle:
            text = raw.strip()
            if not text:
                continue
            if not text.startswith(">"):
                pieces.append(text)
                continue
            # A new header closes the previous record (if any).
            if name is not None:
                yield name, _clean(pieces)
            name, pieces = text[1:].strip(), []
        if name is not None:
            yield name, _clean(pieces)

def detect_col(df: pd.DataFrame, names: List[str]) -> Optional[str]:
    """Find the DataFrame column that best matches one of ``names``.

    Matching is case-insensitive and runs in two passes. The first pass
    looks for an exact match; the second looks for a substring match, where
    the candidate name is contained in the column name. Within each pass,
    ``names`` are tried in order, and for each name the columns are scanned
    in DataFrame order.

    Returns:
        The original column label, or ``None`` if nothing matches.
    """
    lowered = {col: col.lower() for col in df.columns}
    passes = (
        lambda want, have: have == want,  # exact match
        lambda want, have: want in have,  # substring match
    )
    for matches in passes:
        for name in names:
            for col, have in lowered.items():
                if matches(name.lower(), have):
                    return col
    return None

# Case-insensitive substrings; a modification annotation containing any of
# them causes the row to be rejected by keep_valid_mods.
BAD_MOD_KEYWORDS = (
    "acetyl", "biotin", "peg", "palmit", "dye", "fluor", "lipid", "phospho",
    "formyl", "citryl", "methyl", "glycosyl", "label", "dansyl", "tamra",
    "fitc", "fura", "rhodamine", "nme", "nme2",
)

def keep_valid_mods(mod_str: str) -> bool:
    """Return True if a modification annotation should be kept.

    Missing annotations (``None`` or a float NaN) are kept. Any other value
    is converted to a lower-case string and kept only if it contains none of
    ``BAD_MOD_KEYWORDS``.
    """
    if mod_str is None:
        return True
    if isinstance(mod_str, float) and math.isnan(mod_str):
        return True
    lowered = str(mod_str).lower()
    return all(keyword not in lowered for keyword in BAD_MOD_KEYWORDS)

# Target name -> regex searched in the normalised bacterium name
# (lower-case letters separated by single spaces), so both the full genus
# and its one-letter abbreviation match, e.g. "escherichia coli" / "e coli".
TARGETS = {
    "E.coli": r"(?:escherichia|e)\s+coli",
    "S.aureus": r"(?:staphylococcus|s)\s+aureus",
    "P.aeruginosa": r"(?:pseudomonas|p)\s+aeruginosa",
    "B.subtilis": r"(?:bacillus|b)\s+subtilis",
    "C.albicans": r"(?:candida|c)\s+albicans",
}
# Short, filename-friendly names: the target key without the dot, lower-cased
# (e.g. "E.coli" -> "ecoli"). Same insertion order as TARGETS.
SHORT_TARGET = {name: name.replace(".", "").lower() for name in TARGETS}

def build_vocab(exclude_cysteine: bool = True) -> List[str]:
    """Return the token vocabulary: the EOS/padding token "*" followed by amino acids.

    The amino acids are the 20 standard one-letter codes in alphabetical
    order. Cysteine ("C") is left out when ``exclude_cysteine`` is truthy.
    """
    residues = "ACDEFGHIKLMNPQRSTVWY"
    if exclude_cysteine:
        residues = residues.replace("C", "")
    # EOS goes first so it always sits at index 0.
    return ["*", *residues]

def encode_onehot(seq: str, vocab: List[str], max_len: int = 50) -> np.ndarray:
    """One-hot encode ``seq`` as a flat float32 vector of length ``max_len * len(vocab)``.

    Row ``i`` of the underlying ``(max_len, len(vocab))`` grid encodes
    residue ``i``. Every row starts as the EOS token "*" (column 0 if the
    vocabulary has no "*"), so positions past the end of ``seq`` act as
    padding. The sequence is truncated to ``max_len``. A character that is
    not in ``vocab`` leaves its row set to EOS.
    """
    index_of = {}
    for pos, token in enumerate(vocab):
        index_of[token] = pos

    grid = np.zeros((max_len, len(vocab)), dtype=np.float32)
    # EOS is located by symbol, not by assuming a fixed position.
    grid[:, index_of.get("*", 0)] = 1.0
    for row, residue in enumerate(seq[:max_len]):
        col = index_of.get(residue)
        if col is None:
            continue
        grid[row] = 0.0
        grid[row, col] = 1.0
    return grid.ravel()

def build_length_hist(seqs: List[str]) -> np.ndarray:
    """Return the empirical length distribution of the non-empty sequences.

    Returns:
        A ``(k, 2)`` float array with one row per distinct length, sorted
        by length. Column 0 holds the length and column 1 its relative
        frequency, which sums to 1 when any sequence is non-empty. Empty or
        falsy entries are ignored; with no usable sequences the result has
        shape ``(0, 2)``.
    """
    sizes = np.array([len(s) for s in seqs if s])
    distinct, freq = np.unique(sizes, return_counts=True)
    return np.column_stack((distinct, freq / freq.sum()))

def sample_negatives_random(n: int, length_hist: np.ndarray, vocab: List[str], rng: np.random.RandomState) -> List[str]:
    """Generate ``n`` uniformly random peptides as negative examples.

    Lengths are drawn from ``length_hist``, where column 0 holds the length
    and column 1 its probability. Residues are drawn uniformly from every
    vocabulary token except "*". All lengths are drawn first, then each
    sequence in turn, so the output is fully determined by the state of
    ``rng``.
    """
    alphabet = [token for token in vocab if token != "*"]
    sizes = length_hist[:, 0].astype(int)
    weights = length_hist[:, 1]
    drawn = rng.choice(sizes, size=n, p=weights)
    peptides = []
    for size in drawn:
        letters = rng.choice(alphabet, size=size)
        peptides.append("".join(letters))
    return peptides

def preprocess_grampa(grampa_csv: Path,
                      drop_yadamp: bool = True,
                      exclude_cysteine: bool = True,
                      min_len: int = 15,
                      max_len: int = 50) -> pd.DataFrame:
    """Load a GRAMPA CSV and reduce it to one row per (sequence, bacterium).

    Column names are auto-detected with ``detect_col``. Processing order:

    1. Drop rows with a missing sequence or bacterium. Upper-case the
       sequences and strip every non-letter character.
    2. Drop rows whose modification annotation fails ``keep_valid_mods``
       (only if a modifications column exists). When ``drop_yadamp`` is set,
       also drop rows whose source database is "yadamp".
    3. Keep non-empty sequences with ``min_len <= len <= max_len``. When
       ``exclude_cysteine`` is set, drop sequences containing "C".
    4. If a MIC column exists, coerce it to float and rescale mM/nM units to
       µM (a rough heuristic). Then drop non-finite values and aggregate
       duplicate (sequence, bacterium) pairs by geometric mean. Without a
       MIC column, just de-duplicate the pairs.

    Args:
        grampa_csv: Path to the GRAMPA CSV.
        drop_yadamp: Drop rows coming from the YADAMP database.
        exclude_cysteine: Drop sequences that contain cysteine.
        min_len: Minimum sequence length kept (inclusive; empty sequences
            are always dropped).
        max_len: Maximum sequence length kept (inclusive).

    Returns:
        DataFrame with columns ``sequence``, ``bacterium`` and
        ``bacterium_norm`` (lower-case letters separated by single spaces).
        It also has ``MIC_uM`` when a MIC column was detected.

    Raises:
        RuntimeError: if no sequence or bacterium column can be detected.
    """
    df = pd.read_csv(grampa_csv, low_memory=False)

    seq_col = detect_col(df, ["sequence", "seq", "peptide"])
    bact_col = detect_col(df, ["bacterium", "bacteria", "organism", "species", "target"])
    mods_col = detect_col(df, ["modifications", "mods", "mod", "modification"])
    src_col = detect_col(df, ["database", "source_db", "source", "db"])
    mic_col = detect_col(df, ["mic", "mic_um", "mic (um)", "value"])
    unit_col = detect_col(df, ["unit"])

    if seq_col is None or bact_col is None:
        raise RuntimeError("Could not detect sequence/bacterium columns in GRAMPA CSV.")

    # Drop missing values *before* astype(str). Otherwise NaN becomes the string
    # "nan", which the letter-only cleanup turns into a fake peptide "NAN"
    # (and a bogus "nan" bacterium). The explicit copy lets us assign columns
    # without chained-assignment warnings.
    df = df.dropna(subset=[seq_col, bact_col]).copy()
    df[seq_col] = df[seq_col].astype(str).str.upper().str.replace(r"[^A-Z]", "", regex=True)
    df[bact_col] = df[bact_col].astype(str)

    if mods_col is not None:
        df = df[df[mods_col].map(keep_valid_mods).astype(bool)]
    if drop_yadamp and src_col is not None:
        df = df[~df[src_col].astype(str).str.lower().eq("yadamp")]

    # Sequences that were all non-letters are now "" and are always dropped,
    # even when min_len <= 0.
    lengths = df[seq_col].str.len()
    df = df[(lengths >= max(min_len, 1)) & (lengths <= max_len)]
    if exclude_cysteine:
        df = df[~df[seq_col].str.contains("C", regex=False)]

    # Optional MIC processing
    if mic_col is not None:
        mic = pd.to_numeric(df[mic_col], errors="coerce").astype(float)
        if unit_col is not None:
            unit = df[unit_col].astype(str).str.lower()
            # Very rough unit harmonisation to µM: "um"/"µm" values are kept
            # as-is; units containing "mm" are scaled by 1000, "nm" by 1/1000.
            mm_mask = unit.str.contains("mm", regex=False)
            nm_mask = unit.str.contains("nm", regex=False)
            mic = mic.copy()
            mic[mm_mask] = mic[mm_mask] * 1000.0
            mic[nm_mask] = mic[nm_mask] / 1000.0
        df = df.assign(MIC_uM=mic)
        # Rows whose MIC could not be parsed to a finite number are dropped.
        df = df[np.isfinite(df["MIC_uM"])]

    if "MIC_uM" in df.columns:
        # NOTE(review): np.log gives NaN/-inf (with a RuntimeWarning) for
        # MIC values <= 0, so such groups get a NaN or 0 geometric mean. If
        # the detected column is already log-scaled, this aggregate is not a
        # µM value. Nothing else in this script reads MIC_uM; confirm before
        # relying on it.
        df = (df.groupby([seq_col, bact_col], as_index=False)
                .agg(MIC_uM=("MIC_uM", lambda x: float(np.exp(np.mean(np.log(x)))))))
    else:
        df = df.drop_duplicates(subset=[seq_col, bact_col])[[seq_col, bact_col]]
    df = df.rename(columns={seq_col: "sequence", bact_col: "bacterium"})

    df["bacterium_norm"] = (df["bacterium"].str.lower()
                            .str.replace(r"[^a-z ]", " ", regex=True)
                            .str.replace(r"\s+", " ", regex=True)
                            .str.strip())
    return df

def build_pos_sets(df: pd.DataFrame) -> Dict[str, List[str]]:
    """Collect the positive sequences for every pathogen in ``TARGETS``.

    A row counts as positive for a target when its ``bacterium_norm`` value
    matches that target's regex. Missing values never match.

    Returns:
        Mapping from target name to its de-duplicated sequences, in
        first-seen order, with keys in ``TARGETS`` order.
    """
    names = df["bacterium_norm"]

    def _positives(pattern):
        hit = names.str.contains(pattern, na=False, regex=True)
        return df.loc[hit, "sequence"].drop_duplicates().tolist()

    return {target: _positives(pattern) for target, pattern in TARGETS.items()}

def train_rf_for_target(X: np.ndarray, y: np.ndarray, seed: int = 42, n_splits: int = 5):
    """Fit a Random Forest on ``(X, y)`` and estimate its ROC-AUC by stratified CV.

    The number of CV folds is ``min(n_splits, size of the smallest class)``.
    ``StratifiedKFold`` refuses (or warns about) more folds than a class
    has members. Capping the folds this way also guarantees both classes
    in every test fold, which ROC-AUC requires. With a single class, or
    fewer than 2 samples in the smallest class, CV is skipped and the
    returned AUC is NaN. The final model is always fitted on all of
    ``(X, y)``.

    Args:
        X: Feature matrix.
        y: Binary labels.
        seed: Seed for both the forest and the fold shuffling.
        n_splits: Maximum number of CV folds (default 5).

    Returns:
        ``(fitted_classifier, mean_cv_auc)``; ``mean_cv_auc`` may be NaN.
    """
    clf = RandomForestClassifier(
        n_estimators=500,
        max_depth=None,
        max_features="sqrt",
        min_samples_leaf=1,
        n_jobs=-1,
        class_weight="balanced",
        random_state=seed
    )
    classes, counts = np.unique(y, return_counts=True)
    # A fixed 5-fold split would raise for targets with < 5 samples per class
    # and abort the whole training run; shrink the folds instead.
    n_folds = min(n_splits, int(counts.min())) if len(classes) >= 2 else 0
    if n_folds >= 2:
        cv = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
        aucs = cross_val_score(clf, X, y, cv=cv, scoring="roc_auc", n_jobs=-1)
        cv_auc = float(np.mean(aucs))
    else:
        # Too few samples (or a single class) for a meaningful CV estimate.
        cv_auc = float("nan")
    clf.fit(X, y)
    return clf, cv_auc

def prepare_dataset(pos: List[str], neg: List[str], vocab: List[str], max_len: int, seed: int):
    """Build a class-balanced one-hot training set from positive and negative peptides.

    Negatives are resized to ``len(pos)``. When there are fewer negatives
    than positives, they are cycled (repeated in order). Otherwise they are
    subsampled without replacement using ``np.random.RandomState(seed)``.

    Args:
        pos: Positive sequences (label 1).
        neg: Negative sequences (label 0).
        vocab: Token vocabulary passed to ``encode_onehot``.
        max_len: Encoding length passed to ``encode_onehot``.
        seed: Seed for negative subsampling.

    Returns:
        ``(X, y)``. ``X`` stacks the ``encode_onehot`` vectors of the
        positives followed by the negatives, with shape
        ``(2 * len(pos), max_len * len(vocab))``. ``y`` is int64, with 1 for
        positives and 0 for negatives.

    Raises:
        ValueError: if ``pos`` is empty, or if ``neg`` is empty. An empty
            ``neg`` would otherwise silently produce a single-class dataset.
    """
    if not pos:
        raise ValueError("prepare_dataset needs at least one positive sequence.")
    if not neg:
        raise ValueError("prepare_dataset needs at least one negative sequence; got none.")
    rng = np.random.RandomState(seed)
    # balance
    if len(neg) < len(pos):
        reps = -(-len(pos) // len(neg))  # ceil(len(pos) / len(neg))
        neg = (list(neg) * reps)[:len(pos)]
    else:
        neg = list(rng.choice(neg, size=len(pos), replace=False))
    X = np.stack([encode_onehot(s, vocab, max_len) for s in list(pos) + neg])
    y = np.array([1] * len(pos) + [0] * len(neg), dtype=np.int64)
    return X, y

def _sample_negatives_uniprot(n: int, length_hist: np.ndarray, vocab: List[str],
                              rng: np.random.RandomState, proteins: List[str],
                              exclude=frozenset(), max_tries: int = 1000) -> List[str]:
    """Draw ``n`` random substrings of ``proteins`` with lengths from ``length_hist``.

    Lengths are sampled the same way as in ``sample_negatives_random``. A
    fragment is accepted only if every residue is a non-EOS token of
    ``vocab``, so "C" is rejected when the vocabulary excludes it. It must
    also not appear in ``exclude``, which keeps known positives out of the
    negative set.

    Raises:
        RuntimeError: if no protein is long enough for a sampled length, or
            if no acceptable fragment is found within ``max_tries`` draws.
    """
    allowed = set(vocab) - {"*"}
    lens = rng.choice(length_hist[:, 0].astype(int), size=n, p=length_hist[:, 1]).tolist()
    # Candidate proteins per sampled length, computed once per distinct length.
    pools = {L: [s for s in proteins if len(s) >= L] for L in set(lens)}
    fragments = []
    for L in lens:
        pool = pools[L]
        if not pool:
            raise RuntimeError(f"No UniProt sequence is at least {L} residues long.")
        for _ in range(max_tries):
            prot = pool[rng.randint(len(pool))]
            start = rng.randint(len(prot) - L + 1)
            frag = prot[start:start + L]
            if frag not in exclude and set(frag) <= allowed:
                fragments.append(frag)
                break
        else:
            raise RuntimeError(
                f"Could not find an acceptable UniProt fragment of length {L} after {max_tries} draws.")
    return fragments

def main():
    """Command-line entry point.

    Preprocesses the GRAMPA CSV and builds one positive set per entry of
    ``TARGETS``. For each target it draws as many negatives as positives:
    random sequences by default, or random substrings of a UniProt FASTA
    with ``--negatives uniprot``. It then trains one Random Forest per
    target and writes, under ``--out_dir``, ``models/rf_<short>.joblib``,
    ``encoders/vocab.json`` and ``manifest.json``. Exits with status 1 if
    no target has any positive sequence.
    """
    p = argparse.ArgumentParser(description="Preprocess GRAMPA and train RF classifiers per pathogen (DyNA-PPO AMP setup).")
    p.add_argument("--grampa_csv", type=str, default="grampa.csv", help="Path to GRAMPA CSV (e.g., grampa.csv)")
    p.add_argument("--out_dir", type=str, default="./rf_models", help="Output directory for models and artifacts")
    p.add_argument("--negatives", type=str, choices=["random","uniprot"], default="random",
                   help="How to generate negatives (random: fallback; uniprot: substrings from provided FASTA)")
    p.add_argument("--uniprot_fasta", type=str, default=None, help="Path to prefiltered UniProt FASTA (non-AMP cytosolic proteins)")
    # A store_true flag with default=True can never be switched off, so a
    # store_false counterpart writes to the same destination.
    p.add_argument("--exclude_cysteine", dest="exclude_cysteine", action="store_true",
                   help="Exclude sequences containing Cysteine (C)")
    p.add_argument("--include_cysteine", dest="exclude_cysteine", action="store_false",
                   help="Keep Cysteine-containing sequences and add C to the vocabulary")
    p.set_defaults(exclude_cysteine=True)
    p.add_argument("--min_len", type=int, default=1, help="Minimum peptide length to keep (inclusive)")
    p.add_argument("--max_len", type=int, default=10, help="Maximum peptide length to keep (inclusive)")
    p.add_argument("--seed", type=int, default=42)
    args = p.parse_args()

    if args.max_len < 1:
        p.error("--max_len must be at least 1")
    if args.min_len > args.max_len:
        p.error(f"--min_len ({args.min_len}) must not exceed --max_len ({args.max_len})")
    if args.negatives == "uniprot":
        if not args.uniprot_fasta:
            p.error("--negatives uniprot requires --uniprot_fasta")
        if not Path(args.uniprot_fasta).is_file():
            p.error(f"UniProt FASTA not found: {args.uniprot_fasta}")

    out_dir = Path(args.out_dir)
    for sub in ("models", "encoders"):
        (out_dir / sub).mkdir(parents=True, exist_ok=True)

    print(">> Loading and preprocessing GRAMPA ...")
    df = preprocess_grampa(Path(args.grampa_csv), drop_yadamp=True, exclude_cysteine=args.exclude_cysteine, min_len=args.min_len, max_len=args.max_len)

    pos_sets = build_pos_sets(df)
    print(">> Target record counts after filtering:")
    for k, pat in TARGETS.items():
        cnt = df["bacterium_norm"].str.contains(pat, na=False, regex=True).sum()
        print(f"   - {k}: {cnt} rows")

    # build length histogram from union of positives
    union_pos = sorted({s for seqs in pos_sets.values() for s in seqs})
    if not union_pos:
        print("No positives found for any target. Check your CSV columns and TARGETS patterns.", file=sys.stderr)
        sys.exit(1)
    length_hist = build_length_hist(union_pos)

    vocab = build_vocab(exclude_cysteine=args.exclude_cysteine)
    (out_dir / "encoders" / "vocab.json").write_text(json.dumps({"seq_size": args.max_len, "tokens": vocab, "eos": "*", "eos_index": vocab.index("*")}, indent=2))
    rng = np.random.RandomState(args.seed)

    proteins = None
    if args.negatives == "uniprot":
        proteins = [s for _, s in read_fasta(Path(args.uniprot_fasta)) if s]
        if not proteins:
            print(f"No sequences found in {args.uniprot_fasta}.", file=sys.stderr)
            sys.exit(1)
        print(f">> Loaded {len(proteins)} UniProt sequences for negative sampling")
    # Known positives (for any target) must never be sampled as negatives.
    known_pos = frozenset(union_pos)

    report = []
    for target, pos in pos_sets.items():
        if not pos:
            print(f"[{target}] No positives; skipping.")
            continue
        if proteins is None:
            neg = sample_negatives_random(len(pos), length_hist, vocab, rng)
        else:
            neg = _sample_negatives_uniprot(len(pos), length_hist, vocab, rng, proteins, exclude=known_pos)
        X, y = prepare_dataset(pos, neg, vocab, max_len=args.max_len, seed=args.seed)
        print(f"[{target}] Pos={len(pos)} Neg={len(neg)} X={X.shape}")
        model, auc = train_rf_for_target(X, y, seed=args.seed)
        short = SHORT_TARGET[target]
        joblib.dump({"model": model, "vocab": vocab, "max_len": args.max_len, "target": target, "short_target": short}, out_dir/"models"/f"rf_{short}.joblib")
        # NaN is not valid JSON; record a missing CV estimate as null.
        report.append({"target": target, "short": short,
                       "cv_auc": auc if math.isfinite(auc) else None,
                       "n_pos": len(pos), "n_neg": len(neg)})

    (out_dir/"manifest.json").write_text(json.dumps({"report": report}, indent=2))
    print("\n=== Summary ===")
    for r in report:
        auc_txt = "n/a" if r["cv_auc"] is None else f"{r['cv_auc']:.3f}"
        print(f"{r['target']:>12s} | AUC={auc_txt} | Pos={r['n_pos']:>5d} Neg={r['n_neg']:>5d}")
    print(f"\nArtifacts written to: {out_dir.resolve()}")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import; main()
    # returns None, so this exits with status 0 unless main exits itself.
    sys.exit(main())
