# peptide/plot_kmer_embedding.py
from __future__ import annotations

import argparse
import json
import os
import re
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import math 

import numpy as np
import torch

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import FuncFormatter
from matplotlib.lines import Line2D

from .peptide_env import Sequences, Policy
from .peptide_reward import LogReward

try:
    from tqdm import tqdm  # type: ignore
except Exception:
    tqdm = None

# Optional (UMAP)
try:
    import umap  # type: ignore
except Exception:
    umap = None

# Optional (sklearn PCA + KMeans)
try:
    from sklearn.decomposition import PCA  # type: ignore
except Exception:
    PCA = None

# Optional (sklearn KMeans)
try:
    from sklearn.cluster import KMeans  # type: ignore
except Exception:
    KMeans = None

# Optional RDKit
try:
    from rdkit import Chem  # type: ignore
    from rdkit.Chem import Draw  # type: ignore
    _HAS_RDKIT = True
except Exception:
    Chem = None
    Draw = None
    _HAS_RDKIT = False

# Optional joblib (RF models)
try:
    import joblib  # type: ignore
except Exception:
    joblib = None


# -----------------------------
# Matplotlib preamble (same settings as export_icml_figs_simple.py)
# -----------------------------
# Global style applied at import time to every figure this module creates.
# NOTE(review): "text.usetex": True routes all text through LaTeX, so a working
# LaTeX installation (with the packages in the preamble below) is required.
matplotlib.rcParams.update({
    "font.family": "serif",
    "font.size": 14.0,
    "lines.linewidth": 2,
    "lines.antialiased": True,
    "axes.facecolor": "fdfdfd",
    "axes.edgecolor": "777777",
    "axes.linewidth": 1,
    "axes.titlesize": "medium",
    "axes.labelsize": "medium",
    "axes.axisbelow": True,
    "xtick.color": "333333",
    "xtick.labelsize": "medium",
    "xtick.direction": "in",
    "ytick.major.size": 0,
    "ytick.minor.size": 0,
    "ytick.major.pad": 6,
    "ytick.minor.pad": 6,
    "ytick.color": "333333",
    "ytick.labelsize": "medium",
    "ytick.direction": "in",
    "axes.grid": True,
    "grid.alpha": 0.3,
    "grid.linewidth": 1,
    "legend.fancybox": True,
    "legend.fontsize": "Small",
    "figure.facecolor": "1.0",
    "figure.edgecolor": "0.5",
    "hatch.linewidth": 0.1,
    "text.usetex": True,
    # Type 42 (TrueType) fonts in PDF/PS output instead of Type 3.
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
})
# LaTeX packages loaded for all usetex-rendered text (Times font + AMS math).
plt.rcParams["text.latex.preamble"] = r"\usepackage{times, amsmath, amssymb}"


def my_formatter(x, pos):
    """Format a tick value compactly, dropping the leading zero when 0 < |x| < 1.

    ``pos`` is the tick position passed by ``FuncFormatter`` and is ignored.
    Examples: 0.5 -> ".5", -0.25 -> "-.25", 2.0 -> "2", 0.0 -> "0".
    """
    text = "{:g}".format(x)
    magnitude = np.abs(x)
    is_proper_fraction = 0 < magnitude < 1
    return text.replace("0", "", 1) if is_proper_fraction else text


# Shared FuncFormatter instance that apply_format() installs on both axes.
major_formatter = FuncFormatter(my_formatter)


def apply_format(ax: plt.Axes):
    """Install the module-wide ``major_formatter`` on the x and y axes of ``ax``."""
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(major_formatter)


# -----------------------------
# Names + colors (same naming / style conventions)
# -----------------------------
# Display name per method key. In these raw strings "\\" is the LaTeX line
# break, so (with text.usetex enabled) the citation renders on a second line.
field_to_name = {
    "tb": r"Trajectory Balance\\(Malkin et al., NeurIPS 2022)",
    "dtb": r"Divergent (\textbf{Ours})",
    "teacher_student": r"Adaptive Teacher\\(Kim et al., ICLR 2025)",
    "sa": r"Sibling Augmented\\(Madan et al., ICLR 2025)",
}

# Plot color per method key (same keys as field_to_name).
field_to_color = {
    "tb": "#b22222",
    "dtb": "#1f77b4",
    "teacher_student": "#911eb4",
    "sa": "#f58231",
}

# RF group names (fixed order) + colors for the groups.
# NOTE(review): this order matches the model files loaded by RFBundle
# (ecoli, saureus, paeruginosa, bsubtilis, calbicans); keep them in sync.
RF_GROUP_NAMES = ["E. coli", "S. aureus", "P. aeruginosa", "B. subtilis", "C. albicans"]

# One color per entry of RF_GROUP_NAMES (same index).
RF_GROUP_COLORS = [
    "#1f77b4",  # blue
    "#d62728",  # red
    "#2ca02c",  # green
    "#7f7f7f",  # gray
    "#ff7f0e",  # orange
]


# -----------------------
# IO helpers
# -----------------------
def read_json(path: Path) -> dict:
    """Parse the UTF-8 encoded JSON file at ``path`` and return its contents."""
    with open(path, encoding="utf-8") as handle:
        contents = json.load(handle)
    return contents


def parse_seed_arg(seed_arg: str) -> str:
    """Normalize a ``--seed`` value to the ``seed_<N>`` directory name.

    Values already prefixed with ``seed_`` are returned (stripped) as-is;
    anything else is parsed as an integer, so "10" and " 010 " both map to
    "seed_10". Raises ``ValueError`` for non-integer input.
    """
    cleaned = seed_arg.strip()
    return cleaned if cleaned.startswith("seed_") else "seed_{}".format(int(cleaned))


def list_methods(root: Path, run_id: str, seed_dirname: str) -> List[str]:
    """Return (sorted) method directory names under ``root/peptide`` that have this run.

    A method qualifies when ``<method>/<run_id>/<seed_dirname>`` contains both a
    ``config.json`` and a ``checkpoints`` entry.

    Raises:
        FileNotFoundError: if ``root/peptide`` does not exist.
    """
    exp_dir = root / "peptide"
    if not exp_dir.exists():
        raise FileNotFoundError(f"Expected {exp_dir} to exist")

    def _has_run(method_dir: Path) -> bool:
        run_dir = method_dir / run_id / seed_dirname
        return (run_dir / "config.json").exists() and (run_dir / "checkpoints").exists()

    return [d.name for d in sorted(exp_dir.iterdir()) if d.is_dir() and _has_run(d)]


def load_cfg(seed_dir: Path) -> dict:
    """Read ``seed_dir/config.json``.

    Raises:
        FileNotFoundError: if the config file is missing.
    """
    cfg_path = seed_dir / "config.json"
    if cfg_path.exists():
        return read_json(cfg_path)
    raise FileNotFoundError(f"Missing config.json: {cfg_path}")


def find_final_ckpt(seed_dir: Path, epochs: int) -> Path:
    """Locate the checkpoint for the final epoch, or the latest one available.

    Prefers ``checkpoints/epoch_<epochs:06d>.pt``. Otherwise returns the
    ``epoch_*.pt`` file with the largest parsed epoch number (names without a
    numeric epoch rank lowest).

    Raises:
        FileNotFoundError: if no ``epoch_*.pt`` file exists.
    """
    ckpt_dir = seed_dir / "checkpoints"
    exact = ckpt_dir / f"epoch_{epochs:06d}.pt"
    if exact.exists():
        return exact

    candidates = list(ckpt_dir.glob("epoch_*.pt"))
    if not candidates:
        raise FileNotFoundError(f"No checkpoints found under {ckpt_dir}")

    epoch_re = re.compile(r"epoch_(\d+)\.pt")

    def _epoch(path: Path) -> int:
        match = epoch_re.search(path.name)
        return -1 if match is None else int(match.group(1))

    return sorted(candidates, key=_epoch)[-1]


# -----------------------
# Cache helpers (cache files now live in seed_dir/samples/)
# -----------------------
def load_ragged_seqs_npz(path: Path) -> List[Tuple[int, ...]]:
    """Load ragged integer sequences stored as a flat ``tokens`` + ``offsets`` pair.

    The archive must contain ``tokens`` (all sequences concatenated) and
    ``offsets`` (length ``n + 1``), where sequence ``i`` is
    ``tokens[offsets[i]:offsets[i + 1]]``.

    Returns:
        One tuple of Python ints per sequence (empty list if ``offsets`` has
        fewer than two entries).
    """
    # np.load on an .npz returns an NpzFile holding an open file handle; the
    # context manager closes it deterministically. Arrays read via data[...]
    # are fully materialized and remain valid after closing.
    with np.load(path, allow_pickle=False) as data:
        # int64 (not int16) so large token ids cannot silently wrap around.
        tokens = np.asarray(data["tokens"]).astype(np.int64, copy=False)
        offsets = np.asarray(data["offsets"]).astype(np.int64, copy=False)

    starts = offsets[:-1].tolist()
    ends = offsets[1:].tolist()
    return [tuple(tokens[a:b].tolist()) for a, b in zip(starts, ends)]


def cache_filename(mode: str, n_samples: int, thr: float, tag: str = "") -> str:
    """Build the cache file name ``good_samples_<mode>_n<N>_thr<thr>[_<tag>].npz``."""
    stem = f"good_samples_{mode}_n{n_samples}_thr{thr}"
    if tag:
        stem += f"_{tag}"
    return stem + ".npz"


def resolve_cache_path(seed_dir: Path, mode: str, n_samples: int, thr: float, tag: str) -> Path:
    """Return the cache path, preferring ``seed_dir/samples`` over legacy ``seed_dir/analysis``.

    The legacy path is returned whenever the new one does not exist, even if
    the legacy file is missing too; callers check ``.exists()`` themselves.
    """
    fname = cache_filename(mode, n_samples, thr, tag)
    current, legacy = seed_dir / "samples" / fname, seed_dir / "analysis" / fname
    return current if current.exists() else legacy


# -----------------------
# Model / env helpers
# -----------------------
def pick_forward_state_dict(payload: dict) -> Dict[str, torch.Tensor]:
    """Extract the forward-policy ``state_dict`` from a checkpoint payload.

    Lookup order:
      1. the first key from a fixed preference list whose value is a dict;
      2. otherwise, the first non-empty dict value whose first key looks like a
         parameter name (contains "weight", "bias" or ".").
    Teacher / divergence networks (``teacher_fnet``, ``div_fnet``) are never returned.

    Raises:
        KeyError: if no suitable weights dict is found.
    """
    banned = {"teacher_fnet", "div_fnet"}
    preferred = (
        "fnet",
        "student_fnet",
        "st_fnet",
        "forward_fnet",
        "pf",
        "pf_net",
        "policy",
        "model",
        "sa_fnet",
    )

    def _looks_like_weights(value) -> bool:
        if not isinstance(value, dict) or len(value) == 0:
            return False
        first_key = next(iter(value))
        return isinstance(first_key, str) and (
            "weight" in first_key or "bias" in first_key or "." in first_key
        )

    for name in (n for n in preferred if n not in banned):
        candidate = payload.get(name, None)
        if isinstance(candidate, dict):
            return candidate

    for name, candidate in payload.items():
        if name not in banned and _looks_like_weights(candidate):
            return candidate

    raise KeyError(f"Could not find allowed forward policy weights in ckpt keys: {list(payload.keys())}")


def build_policy_from_cfg(cfg: dict, device: str) -> Policy:
    """Instantiate the forward ``Policy`` using architecture sizes stored in ``cfg``.

    Integer-valued ``emb_dim``, ``hidden``, ``pos_dim`` and ``window`` entries are
    forwarded as keyword arguments. Keys that are missing, ``None``, or not
    convertible to ``int`` are left to ``Policy``'s defaults.

    Keyword arguments that ``Policy.__init__`` does not accept are dropped one at
    a time. Previously a single unsupported option discarded every configured
    size, which would likely surface later as a confusing ``load_state_dict``
    shape mismatch.

    Args:
        cfg: run configuration (as read from ``config.json``).
        device: torch device string the model is moved to.

    Returns:
        The constructed policy, already moved to ``device``.
    """
    import inspect

    kw: Dict[str, int] = {}
    for k in ("emb_dim", "hidden", "pos_dim", "window"):
        value = cfg.get(k)
        if value is None:
            continue
        try:
            kw[k] = int(value)
        except (TypeError, ValueError, OverflowError):
            # Unusable value in the config: keep Policy's default for this field.
            continue

    # Keep only keyword arguments the Policy constructor accepts (unless it takes **kwargs).
    try:
        params = inspect.signature(Policy).parameters
    except (TypeError, ValueError):
        params = None
    if params is not None and not any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    ):
        kw = {k: v for k, v in kw.items() if k in params}

    try:
        return Policy(**kw).to(device)
    except TypeError:
        # Last-resort fallback kept for backward compatibility.
        return Policy().to(device)


def make_env(cfg: dict, seq_size: int, batch_size: int, seed: int) -> Sequences:
    """Build a ``Sequences`` sampling environment configured from ``cfg``.

    The reward cutoff is read from ``cut_off``, then ``cutoff``, defaulting to
    0.94; ``eps`` defaults to 0.0.
    """
    legacy_cutoff = cfg.get("cutoff", 0.94)
    cutoff = float(cfg.get("cut_off", legacy_cutoff))
    eps = float(cfg.get("eps", 0.0))
    reward_fn = LogReward(cutoff=cutoff)
    env = Sequences(
        seq_size=seq_size,
        batch_size=batch_size,
        log_reward=reward_fn,
        eps=eps,
        seed=seed,
    )
    return env


@torch.no_grad()
def forward_sampling(env: Sequences, forward_net: torch.nn.Module):
    """Roll out the forward policy in place on ``env``, one position per step.

    At step ``i`` the policy is queried only for rows still flagged in
    ``env.alive``. The chosen actions are written into column ``i`` of
    ``env.state``, and rows whose action is ``0`` stop being alive, so they
    receive no further tokens. The loop ends early once no row is alive.
    Nothing is returned; ``env.state`` and ``env.alive`` are mutated.

    NOTE(review): assumes ``env.state`` is (batch, seq_size) and ``env.alive``
    is a per-row mask, as the indexing below implies; confirm against ``Sequences``.
    """
    for i in range(env.seq_size):
        # Row indices of sequences that have not emitted the terminal token yet.
        active = env.alive.nonzero(as_tuple=True)[0]
        if active.numel() == 0:
            break
        s_sub = env.state.index_select(0, active)
        logits = forward_net(s_sub)
        # training=False presumably disables exploration in action selection; verify in Sequences.get_actions.
        actions = env.get_actions(logits, training=False)
        env.state[active, i] = actions
        # Token 0 terminates a sequence.
        env.alive[active] = (actions != 0)


def state_row_to_tuple(row: torch.Tensor) -> Tuple[int, ...]:
    """Convert one state row to a tuple of ints, truncated at the first 0 (terminal) token."""
    values = [int(v) for v in row.tolist()]
    end = values.index(0) if 0 in values else len(values)
    return tuple(values[:end])


# -----------------------
# k-mer features
# -----------------------
def kmer_features(
    seq: Tuple[int, ...],
    *,
    vocab_size: int = 20,
    k: int = 2,
    include_unigram: bool = True,
    include_length: bool = True,
    length_norm: float = 1.0,
) -> np.ndarray:
    """Featurize a token sequence as normalized k-mer frequencies (float32).

    Layout of the returned vector:
      * unigram frequencies, ``vocab_size`` entries (if ``include_unigram``),
        normalized by the sequence length;
      * bigram frequencies, ``vocab_size**2`` entries (only when ``k == 2``),
        normalized by ``max(1, len(seq) - 1)``;
      * ``len(seq) / length_norm`` (if ``include_length``).

    Tokens are 1-based (``1..vocab_size``). Tokens outside that range are not
    counted, but still contribute to the normalizing length.

    Raises:
        ValueError: if ``k`` is not 1 or 2, or if no feature group is selected.
    """
    n_tok = len(seq)
    ids = np.asarray(seq, dtype=np.int64).reshape(-1)
    in_vocab = (ids >= 1) & (ids <= vocab_size)
    blocks: List[np.ndarray] = []

    if include_unigram:
        uni = np.bincount(ids[in_vocab] - 1, minlength=vocab_size).astype(np.float32)
        if n_tok > 0:
            uni /= float(n_tok)
        blocks.append(uni)

    if k == 2:
        left, right = ids[:-1], ids[1:]
        pair_ok = in_vocab[:-1] & in_vocab[1:]
        pair_idx = (left[pair_ok] - 1) * vocab_size + (right[pair_ok] - 1)
        bi = np.bincount(pair_idx, minlength=vocab_size * vocab_size).astype(np.float32)
        bi /= float(max(1, n_tok - 1))
        blocks.append(bi)
    elif k != 1:
        raise ValueError("Only k=1 or k=2 supported.")

    if include_length:
        blocks.append(np.array([float(n_tok) / float(length_norm)], dtype=np.float32))

    return np.concatenate(blocks, axis=0)


# -----------------------
# RF group coloring helpers (argmax over the 5 RF models)
# -----------------------
# The 20 standard amino acids; token id t (1-based) maps to _AA20[t - 1].
_AA20 = "ACDEFGHIKLMNPQRSTVWY"


def ids_to_aa_string(seq: Tuple[int, ...], vocab_size: int = 20) -> str:
    """Decode token ids to a one-letter amino-acid string, stopping at the first 0.

    Ids outside ``1..vocab_size``, or any id when ``vocab_size != 20``, become "X".
    """
    letters: List[str] = []
    for tok in seq:
        if tok == 0:
            break
        known = vocab_size == 20 and 1 <= tok <= vocab_size
        letters.append(_AA20[tok - 1] if known else "X")
    return "".join(letters)


class RFBundle:
    """Ensemble of five per-organism random-forest activity classifiers (joblib bundles).

    Models are loaded in the fixed order E. coli, S. aureus, P. aeruginosa,
    B. subtilis, C. albicans (same order as ``RF_GROUP_NAMES``), so column ``j``
    of :meth:`predict_proba_all` corresponds to group ``j``. Each bundle is a
    dict with keys ``"model"``, ``"vocab"``, ``"max_len"`` and optionally
    ``"eos_index"``; all bundles must share vocab and max_len.
    """

    def __init__(self, models_dir: Path):
        if joblib is None:
            raise RuntimeError("joblib não está instalado, mas --color_by_rf_group foi usado.")

        self.models_dir = models_dir
        files = [
            "rf_ecoli.joblib",
            "rf_saureus.joblib",
            "rf_paeruginosa.joblib",
            "rf_bsubtilis.joblib",
            "rf_calbicans.joblib",
        ]
        paths = [models_dir / f for f in files]
        missing = [str(p) for p in paths if not p.exists()]
        if missing:
            raise FileNotFoundError(
                "Não encontrei os modelos RF.\n"
                f"  MODELS_DIR = {models_dir}\n"
                f"  Faltando: {missing}\n"
                "Passe --rf_models_dir (caminho para rf_models/models) ou defina RF_MODELS_DIR."
            )

        # SECURITY: joblib.load unpickles arbitrary objects; only load trusted model files.
        self.bundles = [joblib.load(str(p)) for p in paths]  # type: ignore
        self.vocab = self.bundles[0]["vocab"]
        self.max_len = int(self.bundles[0]["max_len"])
        # Explicit checks rather than `assert`, which is stripped under `python -O`.
        if not all(b["vocab"] == self.vocab and int(b["max_len"]) == self.max_len for b in self.bundles):
            raise ValueError("RF models disagree on vocab/max_len.")

        self.tok = {t: i for i, t in enumerate(self.vocab)}
        if "*" not in self.tok:
            raise ValueError("RF vocab must contain '*' (EOS).")
        self.eos_idx = int(self.bundles[0].get("eos_index", self.tok["*"]))
        if self.eos_idx != 0:
            raise ValueError(f"Expected EOS at index 0, got {self.eos_idx}.")
        self.V = len(self.vocab)

        # When the RF vocab is exactly ["*", A, C, ..., Y], policy token ids
        # (0 = EOS, 1..20 = _AA20) can be used as RF ids without remapping.
        self._identity = (
            self.V == 21
            and self.vocab[0] == "*"
            and list(self.vocab[1:21]) == list(_AA20)
        )
        # Amino acids missing from the RF vocab are mapped to EOS.
        self._aa_to_rf = {aa: int(self.tok.get(aa, self.eos_idx)) for aa in _AA20}

    def _map_ids_to_rf_ids(self, batch_ids: torch.Tensor) -> torch.Tensor:
        """Translate policy token ids [B, L] to RF vocab ids [B, max_len] (truncate / EOS-pad)."""
        if batch_ids.ndim != 2:
            raise ValueError(f"batch_ids deve ter shape [B, L], recebido {tuple(batch_ids.shape)}")
        B, L = batch_ids.shape
        Luse = min(L, self.max_len)
        x = batch_ids[:, :Luse].to(torch.long).clone()

        if self._identity:
            # Out-of-range ids become EOS.
            x = torch.where((x >= 0) & (x < self.V), x, torch.full_like(x, self.eos_idx))
        else:
            # Start from all-EOS (covers id 0 and unknown ids), then fill each amino acid.
            mapped = torch.full_like(x, self.eos_idx)
            for i, aa in enumerate(_AA20, start=1):
                rf_i = self._aa_to_rf[aa]
                mapped = torch.where(x == i, torch.full_like(mapped, rf_i), mapped)
            x = mapped

        if Luse < self.max_len:
            # Pad on the same device as x so torch.cat also works for GPU inputs.
            pad = torch.full((B, self.max_len - Luse), self.eos_idx, dtype=torch.long, device=x.device)
            x = torch.cat([x, pad], dim=1)
        return x

    def _encode_onehot_np(self, batch_ids_rf: torch.Tensor) -> np.ndarray:
        """One-hot encode RF ids [B, max_len] into a flat float32 array [B, max_len * V]."""
        B, Lmax = batch_ids_rf.shape
        if Lmax != self.max_len:
            raise ValueError(f"Expected L={self.max_len}, got {Lmax}")

        arr = batch_ids_rf.detach().cpu().numpy().astype(np.int64, copy=False)
        # Invalid ids fall back to the EOS one-hot (same as the padding convention).
        ids = np.where((arr >= 0) & (arr < self.V), arr, self.eos_idx)
        X = np.zeros((B, self.max_len, self.V), dtype=np.float32)
        np.put_along_axis(X, ids[:, :, None], 1.0, axis=2)
        return X.reshape(B, -1)

    def predict_proba_all(self, batch_ids: torch.Tensor, batch_size: int = 4096) -> np.ndarray:
        """Return P(active) from every model: float32 array [N, n_models], processed in chunks.

        Uses column 1 of each model's ``predict_proba`` as the positive-class probability.
        """
        if batch_ids.ndim != 2:
            raise ValueError(f"batch_ids must be [N, L], got {tuple(batch_ids.shape)}")

        N = batch_ids.shape[0]
        out = np.zeros((N, len(self.bundles)), dtype=np.float32)

        for a in range(0, N, batch_size):
            b = min(N, a + batch_size)
            x_rf = self._map_ids_to_rf_ids(batch_ids[a:b])
            Xi = self._encode_onehot_np(x_rf)
            preds = [bundle["model"].predict_proba(Xi)[:, 1] for bundle in self.bundles]
            out[a:b] = np.stack(preds, axis=1).astype(np.float32, copy=False)

        return out

    def argmax_group(self, batch_ids: torch.Tensor, *, min_p: float = 0.0, batch_size: int = 4096) -> np.ndarray:
        """Index of the most probable group per sequence (int32).

        If ``min_p > 0``, sequences whose best probability is below ``min_p``
        get ``-1`` (unknown).
        """
        P = self.predict_proba_all(batch_ids, batch_size=batch_size)
        g = P.argmax(axis=1).astype(np.int32)
        if min_p > 0:
            g[P.max(axis=1) < float(min_p)] = -1
        return g


# -----------------------
# Collect features per method
# -----------------------
def features_from_cached_samples(
    seed_dir: Path,
    *,
    cfg: dict,
    mode: str,
    cache_n_samples: int,
    cache_thr: float,
    cache_tag: str,
    k: int,
    include_unigram: bool,
    include_length: bool,
    vocab_size: int,
) -> Tuple[np.ndarray, List[Tuple[int, ...]]]:
    """Featurize the sequences stored in a ``good_samples_*.npz`` cache.

    The length feature is normalized by ``cfg["seq_size"]``.

    Returns:
        ``(X, seqs)`` with one float32 feature row per cached sequence, or
        ``(zeros((0, 1)), [])`` when the cache file is missing or empty.
    """
    length_norm = float(int(cfg["seq_size"]))

    cache_path = resolve_cache_path(seed_dir, mode, cache_n_samples, cache_thr, cache_tag)
    seqs = load_ragged_seqs_npz(cache_path) if cache_path.exists() else []
    if not seqs:
        return np.zeros((0, 1), dtype=np.float32), []

    rows = []
    for seq in seqs:
        rows.append(
            kmer_features(
                seq,
                vocab_size=vocab_size,
                k=k,
                include_unigram=include_unigram,
                include_length=include_length,
                length_norm=length_norm,
            )
        )
    return np.stack(rows, axis=0).astype(np.float32, copy=False), seqs


def collect_unique_good_sequences_by_sampling_final_ckpt(
    seed_dir: Path,
    ckpt_path: Path,
    *,
    n_samples: int,
    batch_size: int,
    min_logr: float,
    device: str,
    k: int,
    include_unigram: bool,
    include_length: bool,
    vocab_size: int,
    progress: bool,
) -> Tuple[np.ndarray, List[Tuple[int, ...]]]:
    """Sample from a checkpoint's forward policy and featurize the unique good sequences.

    Runs ``ceil(n_samples / batch_size)`` rounds of ``batch_size`` samples each,
    so the total can exceed ``n_samples``. It keeps rows with
    ``env.log_reward() >= min_logr``, drops empty and duplicate sequences, and
    preserves first-seen order.

    Returns:
        ``(X, seqs)`` with one float32 k-mer feature row per entry of ``seqs``,
        or ``(zeros((0, 1)), [])`` when nothing passes the filter.
    """
    cfg = load_cfg(seed_dir)
    seq_size = int(cfg["seq_size"])
    env_seed = int(cfg.get("seed", 0))
    norm = float(seq_size)

    checkpoint = torch.load(ckpt_path, map_location="cpu")
    policy = build_policy_from_cfg(cfg, device=device)
    policy.load_state_dict(pick_forward_state_dict(checkpoint))
    policy.eval()

    env = make_env(cfg, seq_size=seq_size, batch_size=batch_size, seed=env_seed)

    seen: set[Tuple[int, ...]] = set()
    kept_seqs: List[Tuple[int, ...]] = []
    kept_feats: List[np.ndarray] = []

    rounds = range((n_samples + batch_size - 1) // batch_size)
    if progress and tqdm is not None:
        # seed_dir is <root>/peptide/<method>/<run_id>/<seed>; label the bar with <method>.
        rounds = tqdm(rounds, desc=f"{seed_dir.parent.parent.name}", leave=False)

    with torch.no_grad():
        for _ in rounds:
            env.reset()
            forward_sampling(env, policy)
            passing = (env.log_reward() >= min_logr).nonzero(as_tuple=True)[0]
            if passing.numel() == 0:
                continue
            for row in env.state.index_select(0, passing):
                seq = state_row_to_tuple(row)
                if not seq or seq in seen:
                    continue
                seen.add(seq)
                kept_seqs.append(seq)
                kept_feats.append(
                    kmer_features(
                        seq,
                        vocab_size=vocab_size,
                        k=k,
                        include_unigram=include_unigram,
                        include_length=include_length,
                        length_norm=norm,
                    )
                )

    if not kept_feats:
        return np.zeros((0, 1), dtype=np.float32), []
    return np.stack(kept_feats, axis=0).astype(np.float32, copy=False), kept_seqs


# -----------------------
# Dim reduction
# -----------------------
def _svd_project_2d(X: np.ndarray) -> np.ndarray:
    """Project ``X`` onto its top-2 principal directions with a plain NumPy SVD.

    Always returns an ``(N, 2)`` float32 array. Components that do not exist
    (when ``min(N, D) < 2``) are zero-filled.
    """
    Xc = X - X.mean(axis=0, keepdims=True)
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    Z = Xc @ Vt[:2].T
    if Z.shape[1] < 2:
        Z = np.pad(Z, ((0, 0), (0, 2 - Z.shape[1])))
    return Z.astype(np.float32, copy=False)


def reduce_to_2d(
    X: np.ndarray,
    *,
    reducer: str,
    pca_pre: int,
    seed: int,
    umap_n_neighbors: int,
    umap_min_dist: float,
) -> np.ndarray:
    """Embed feature rows ``X`` (N, D) into 2-D with PCA or UMAP.

    Args:
        reducer: ``"pca"`` or ``"umap"``.
        pca_pre: for UMAP only, optional PCA pre-reduction dimension (<= 0 disables);
            it is clamped to ``min(N, D)`` as sklearn requires.
        seed: random state for PCA / UMAP.
        umap_n_neighbors, umap_min_dist: UMAP hyper-parameters.

    Returns:
        ``(N, 2)`` float32 embedding (``(0, 2)`` for empty input).

    Raises:
        RuntimeError: if UMAP is requested but umap-learn is not installed.
        ValueError: for an unknown ``reducer``.
    """
    if X.shape[0] == 0:
        return np.zeros((0, 2), dtype=np.float32)

    if reducer == "pca":
        # sklearn's 2-component PCA cannot fit a single point or a single
        # feature; the SVD path handles those and zero-pads to 2 columns.
        if min(X.shape) < 2 or PCA is None:
            return _svd_project_2d(X)
        pca = PCA(n_components=2, random_state=seed, svd_solver="randomized")
        return pca.fit_transform(X).astype(np.float32, copy=False)

    if reducer == "umap":
        if umap is None:
            raise RuntimeError("UMAP reducer requested but umap-learn is not installed.")
        X_in = X
        if pca_pre and pca_pre > 0 and PCA is not None:
            # PCA requires n_components <= min(n_samples, n_features).
            n_comp = min(int(pca_pre), X.shape[0], X.shape[1])
            pca = PCA(
                n_components=n_comp,
                random_state=seed,
                svd_solver="randomized",
            )
            X_in = pca.fit_transform(X).astype(np.float32, copy=False)

        um = umap.UMAP(
            n_components=2,
            n_neighbors=int(umap_n_neighbors),
            min_dist=float(umap_min_dist),
            metric="euclidean",
            random_state=int(seed),
        )
        return um.fit_transform(X_in).astype(np.float32, copy=False)

    raise ValueError(f"Unknown reducer: {reducer}")

def save_legend(handles, labels, out_path: str | Path, *, ncol=2, fontsize=10,
                handlelength=2.2, columnspacing=1.4):
    """Render ``handles``/``labels`` as a standalone frameless legend image.

    The legend is centered on a small figure and saved to ``out_path`` at
    300 dpi with a tight bounding box; parent directories are created as needed.
    """
    target = Path(out_path)
    legend_fig = plt.figure(figsize=(6.2, 0.55))
    legend_kwargs = dict(
        loc="center",
        ncol=int(ncol),
        frameon=False,
        fontsize=fontsize,
        handlelength=handlelength,
        columnspacing=columnspacing,
    )
    legend_fig.legend(handles, labels, **legend_kwargs)
    legend_fig.canvas.draw()
    target.parent.mkdir(parents=True, exist_ok=True)
    legend_fig.savefig(target, bbox_inches="tight", pad_inches=0.0, dpi=300)
    plt.close(legend_fig)
    
# -----------------------
# Representatives (on a single method's embedding)
# -----------------------
def robust_axis_limits(Z: np.ndarray, quantile_clip: float = 0.99) -> Tuple[float, float, float, float]:
    """Outlier-robust axis limits ``(xmin, xmax, ymin, ymax)`` for 2-D points.

    Each axis spans the ``[1 - q, q]`` quantiles, with ``q`` clamped to
    ``[0.5, 1.0]``. If any bound is non-finite or either range collapses, both
    axes fall back to the full min/max. A 2% margin is added on every side.
    """
    q_hi = min(max(float(quantile_clip), 0.5), 1.0)
    (xmin, xmax), (ymin, ymax) = [np.quantile(Z[:, axis], [1 - q_hi, q_hi]) for axis in (0, 1)]

    all_finite = np.isfinite([xmin, xmax, ymin, ymax]).all()
    if not all_finite or xmin == xmax or ymin == ymax:
        xs, ys = Z[:, 0], Z[:, 1]
        xmin, xmax = float(xs.min()), float(xs.max())
        ymin, ymax = float(ys.min()), float(ys.max())

    margin_x = 0.02 * (xmax - xmin + 1e-9)
    margin_y = 0.02 * (ymax - ymin + 1e-9)
    return (xmin - margin_x, xmax + margin_x, ymin - margin_y, ymax + margin_y)


def select_reps_kcenter(Z: np.ndarray, K: int, seed: int = 0) -> np.ndarray:
    """Pick up to ``K`` well-spread representatives with greedy farthest-point (k-center) selection.

    Starts from the point farthest from the centroid, then repeatedly adds the
    point with the largest squared distance to the chosen set. Near-ties are
    broken at random with ``seed``.

    Returns:
        Sorted unique row indices (may be fewer than ``K`` when points repeat).
        Empty for empty ``Z`` or ``K <= 0``; previously ``K <= 0`` raised IndexError.
    """
    N = Z.shape[0]
    K = int(min(K, N))
    if N == 0 or K <= 0:
        return np.array([], dtype=int)
    Zf = Z.astype(np.float32, copy=False)

    mu = Zf.mean(axis=0, keepdims=True)
    first = int(np.argmax(np.sum((Zf - mu) ** 2, axis=1)))

    chosen = np.empty((K,), dtype=np.int32)
    chosen[0] = first

    # Squared distance from every point to its nearest chosen representative.
    min_d2 = np.sum((Zf - Zf[first:first + 1]) ** 2, axis=1)

    rng = np.random.default_rng(seed)
    for t in range(1, K):
        m = float(min_d2.max())
        cand = np.where(np.isclose(min_d2, m, rtol=1e-6, atol=1e-8))[0]
        nxt = int(rng.choice(cand)) if cand.size > 1 else int(np.argmax(min_d2))
        chosen[t] = nxt

        d2_new = np.sum((Zf - Zf[nxt:nxt + 1]) ** 2, axis=1)
        min_d2 = np.minimum(min_d2, d2_new)

    return np.array(sorted(set(int(i) for i in chosen.tolist())), dtype=int)


def select_reps_kmeans(Z: np.ndarray, K: int, seed: int = 0) -> np.ndarray:
    """Pick representatives as the points nearest to each of ``K`` k-means centroids.

    Returns:
        Sorted unique row indices; empty for empty ``Z`` or ``K <= 0``
        (consistent with ``select_reps_kcenter``).

    Raises:
        RuntimeError: if scikit-learn is not installed.
    """
    if KMeans is None:
        raise RuntimeError("KMeans requested but scikit-learn not installed. Use --focus_rep_selector kcenter.")
    N = Z.shape[0]
    K = int(min(K, N))
    if N == 0 or K <= 0:
        return np.array([], dtype=int)
    km = KMeans(n_clusters=K, n_init=10, random_state=seed)
    labels = km.fit_predict(Z)
    centers = km.cluster_centers_.astype(np.float32, copy=False)

    reps: List[int] = []
    for c in range(K):
        idx = np.where(labels == c)[0]
        if idx.size == 0:
            continue
        pts = Z[idx].astype(np.float32, copy=False)
        d2 = np.sum((pts - centers[c:c + 1]) ** 2, axis=1)
        reps.append(int(idx[int(np.argmin(d2))]))
    return np.array(sorted(set(reps)), dtype=int)


# -----------------------
# RDKit drawing
# -----------------------
def rdkit_image_from_sequence(aa_seq: str, size: int = 220) -> Optional[np.ndarray]:
    """Render a one-letter peptide sequence as a ``size x size`` RDKit image array.

    Returns ``None`` for an empty/None sequence, when RDKit is unavailable, or
    when RDKit cannot build or draw the molecule.
    """
    if aa_seq is None or len(aa_seq) == 0:
        return None
    if not _HAS_RDKIT:
        return None
    try:
        mol = Chem.MolFromSequence(aa_seq)  # type: ignore
        if mol is None:
            return None
        return np.asarray(Draw.MolToImage(mol, size=(size, size)))
    except Exception:
        # Best effort: any RDKit failure simply means "no image" (callers get None).
        return None


def render_rep_to_ax(
    ax: plt.Axes,
    seq: Tuple[int, ...],
    *,
    vocab_size: int,
    img_size: int,
    fallback_text: bool,
    text_trunc: int,
    text_fontsize: int,
    label: str,
):
    """Draw one representative peptide into a bare axes panel.

    A bold ``label`` is placed in the top-left corner. The RDKit structure image
    is shown when available. Otherwise, if ``fallback_text`` is set, the
    one-letter sequence is shown centered, truncated to ``text_trunc``
    characters plus an ellipsis.
    """
    for hide_ticks in (ax.set_xticks, ax.set_yticks):
        hide_ticks([])
    ax.set_frame_on(False)

    sequence_str = ids_to_aa_string(seq, vocab_size=vocab_size)
    image = rdkit_image_from_sequence(sequence_str, size=img_size)

    ax.text(
        0.02, 0.98, label,
        transform=ax.transAxes,
        ha="left", va="top",
        fontsize=11, fontweight="bold",
    )

    if image is not None:
        ax.imshow(image)
    elif fallback_text:
        shown = sequence_str if len(sequence_str) <= text_trunc else sequence_str[:text_trunc] + "…"
        ax.text(0.5, 0.5, shown, transform=ax.transAxes, ha="center", va="center", fontsize=text_fontsize)


# -----------------------
# Plot helper: draw RF groups with zorder inversely related to group size (largest behind)
# -----------------------
def scatter_by_group_with_inverse_count_zorder(
    ax: plt.Axes,
    Zxy: np.ndarray,
    g: np.ndarray,
    *,
    point_size: float,
    point_alpha: float,
    colors: List[str],
    unknown_color: str = "black",
    unknown_zorder: int = 999,
):
    """
    Scatter points colored by group, stacking larger groups underneath smaller ones.

    Groups are drawn from LARGEST to SMALLEST count, so the largest group sits at
    the back and the smallest in front. Points with a negative group id
    (unknown) are drawn in ``unknown_color`` at ``unknown_zorder`` (on top by default).

    g: int array with values in {0..len(colors)-1} and optionally -1 (unknown).
    """
    groups = g.astype(np.int32, copy=False)
    n_groups = len(colors)

    known = groups[groups >= 0]
    if known.size > 0:
        counts = np.bincount(known, minlength=n_groups)
    else:
        counts = np.zeros((n_groups,), dtype=np.int64)

    # Descending by count; earlier draw position => lower zorder => further behind.
    for position, gi in enumerate(np.argsort(-counts).tolist()):
        members = groups == gi
        if not members.any():
            continue
        ax.scatter(
            Zxy[members, 0], Zxy[members, 1],
            s=float(point_size),
            alpha=float(point_alpha),
            rasterized=True,
            color=colors[gi],
            zorder=1 + position,
        )

    unknown = groups < 0
    if unknown.any():
        ax.scatter(
            Zxy[unknown, 0], Zxy[unknown, 1],
            s=float(point_size),
            alpha=float(point_alpha),
            rasterized=True,
            color=unknown_color,
            zorder=int(unknown_zorder),
        )


# -----------------------
# Focus mosaic figure
# -----------------------
def make_focus_mosaic_figure(
    *,
    Z: np.ndarray,
    labels: np.ndarray,
    seqs: List[Tuple[int, ...]],
    focus_method: str,
    out_path: Path,
    xylim_quantile: float,
    rep_k: int,
    rep_selector: str,
    reducer_name: str,
    run_id: str,
    seed_dirname: str,
    img_size: int,
    fallback_text: bool,
    text_trunc: int,
    text_fontsize: int,
    reducer_seed: int,
    vocab_size: int,
    max_points: int,
    point_alpha: float,
    point_size: float,
    group_idx_all: Optional[np.ndarray],
    group_colors: Optional[List[str]],
):
    """Save a figure showing one method's 2-D embedding next to its representative peptides.

    Left: scatter of the focus method's points, optionally colored by RF group,
    with numbered labels at the selected representatives. Right: a two-column
    grid where panel ``r`` renders representative ``r`` (RDKit image, or the
    sequence text when ``fallback_text`` is set).

    Args:
        Z: 2-D embedding of all points, row-aligned with ``labels`` and ``seqs``.
        labels: method name per row; rows equal to ``focus_method`` are used.
        rep_k / rep_selector: number of representatives and ``"kmeans"`` or
            k-center (any other value) selection, seeded with ``reducer_seed``.
        max_points: if > 0, the scatter is randomly subsampled (fixed RNG seed 0)
            to this many points; representatives are chosen from all points.
        group_idx_all / group_colors: optional RF group index per row of ``Z``
            (-1 = unknown) and the color per group; both are needed for coloring.
        reducer_name, run_id, seed_dirname: currently unused here.

    Raises:
        ValueError: if the focus method has no points or no representative is found.
    """
    idx_focus = np.where(labels == focus_method)[0]
    if idx_focus.size == 0:
        raise ValueError(f"Focus method not present: {focus_method}")

    Zf = Z[idx_focus]
    seqs_f = [seqs[i] for i in idx_focus.tolist()]

    xmin, xmax, ymin, ymax = robust_axis_limits(Zf, quantile_clip=float(xylim_quantile))

    # Representatives are selected on the full focus set, before any plot subsampling.
    if rep_selector == "kmeans":
        reps_local = select_reps_kmeans(Zf, K=int(rep_k), seed=int(reducer_seed))
    else:
        reps_local = select_reps_kcenter(Zf, K=int(rep_k), seed=int(reducer_seed))

    if reps_local.size == 0:
        raise ValueError("No representatives found for focus method (empty set).")

    # Plot-only subsample with a fixed seed for reproducible figures.
    idx_plot = np.arange(Zf.shape[0])
    if max_points and max_points > 0 and idx_plot.size > max_points:
        rng = np.random.default_rng(0)
        idx_plot = rng.choice(idx_plot, size=int(max_points), replace=False)

    # Layout: wide scatter column on the left, representatives in a 2-column grid on the right.
    K = int(reps_local.size)
    nrows_rep = int(math.ceil(K / 2))

    fig = plt.figure(figsize=(12.5, max(6, 2.0 * nrows_rep)))

    gs = GridSpec(
        nrows=nrows_rep,
        ncols=3,
        width_ratios=[3.3, 1.2, 1.2],
        wspace=0.0,
        hspace=0.0,
        figure=fig,
    )

    ax_scatter = fig.add_subplot(gs[:, 0])

    if group_idx_all is None or group_colors is None:
        ax_scatter.scatter(
            Zf[idx_plot, 0], Zf[idx_plot, 1],
            s=2*float(point_size),
            alpha=float(point_alpha),
            rasterized=True,
            zorder=1,
        )
    else:
        g_focus = group_idx_all[idx_focus]   # group per focus point
        g_plot = g_focus[idx_plot]           # group per plotted point
        Zplot = Zf[idx_plot]
        scatter_by_group_with_inverse_count_zorder(
            ax_scatter,
            Zplot,
            g_plot,
            point_size=2*float(point_size),
            point_alpha=float(point_alpha),
            colors=group_colors,
        )

    # Numbered labels at each representative's location, drawn above all points.
    for t, j in enumerate(reps_local.tolist(), start=1):
        x, y = float(Zf[j, 0]), float(Zf[j, 1])
        ax_scatter.text(
            x, y, str(t),
            color="#040404",
            fontsize=12,
            ha="center",
            va="center",
            zorder=1001,
            bbox=dict(boxstyle="round", facecolor="white", edgecolor=None, alpha=0.45),
        )

    ax_scatter.set_xlim(xmin, xmax)
    ax_scatter.set_ylim(ymin, ymax)
    ax_scatter.set_title("Efficient peptides on k-mer embedding space")
    apply_format(ax_scatter)

    # Representative r goes to row (r-1)//2, column 1 or 2 of the grid.
    for r, j in enumerate(reps_local.tolist(), start=1):
        rr = (r - 1) // 2
        cc = 1 + ((r - 1) % 2)
        ax_rep = fig.add_subplot(gs[rr, cc])
        render_rep_to_ax(
            ax_rep,
            seqs_f[j],
            vocab_size=int(vocab_size),
            img_size=int(img_size),
            fallback_text=bool(fallback_text),
            text_trunc=int(text_trunc),
            text_fontsize=int(text_fontsize),
            label=str(r),
        )

    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=220, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved focus mosaic: {out_path}")


# -----------------------
# Args / main
# -----------------------
def parse_args():
    """Parse the command-line options of the k-mer embedding plotting script.

    Groups: run selection, on-the-fly sampling, k-mer features, cached samples,
    dimensionality reduction, scatter appearance, focus mosaic, representative
    rendering, RF-group coloring, and output.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Run selection.
    add("--run_id", required=True, type=str)
    add("--seed", required=True, type=str, help='e.g. "10" or "seed_10"')
    add("--root", default="runs", type=str)

    # Sampling from the final checkpoint.
    add("--n_samples", default=1_000, type=int)
    add("--batch_size", default=4096, type=int)
    add("--min_logr", default=0.0, type=float)
    add("--device", default="cpu", type=str)

    # k-mer features.
    add("--k", default=2, type=int, choices=[1, 2])
    add("--no_unigram", action="store_true")
    add("--no_length", action="store_true")
    add("--vocab_size", default=20, type=int)

    # Cached good samples.
    add("--use_cache", action="store_true")
    add("--cache_mode", default="final", choices=["final", "cumulative"])
    add("--cache_n_samples", default=50_000, type=int)
    add("--cache_thr", default=0.0, type=float)
    add("--cache_tag", default="", type=str)

    # Dimensionality reduction.
    add("--reducer", default="pca", choices=["pca", "umap"])
    add("--pca_pre", default=50, type=int)
    add("--umap_n_neighbors", default=30, type=int)
    add("--umap_min_dist", default=0.1, type=float)
    add("--reducer_seed", default=0, type=int)

    # Scatter appearance.
    add("--max_points", default=0, type=int, help="0 => plot all; else subsample per method (plot-only)")
    add("--alpha", default=0.15, type=float)
    add("--s", default=6.0, type=float)
    add("--subplots", action="store_true", help="one subplot per method + 1 reps subplot")
    add("--xylim_quantile", default=0.99, type=float)

    # Focus mosaic.
    add("--focus_method", default=None, type=str)
    add("--focus_rep_k", default=8, type=int)
    add("--focus_rep_selector", default="kcenter", choices=["kcenter", "kmeans"])
    add("--focus_out", default=None, type=str)

    # Representative rendering.
    add("--rep_img_size", default=220, type=int)
    add("--rep_fallback_text", action="store_true")
    add("--rep_text_trunc", default=18, type=int)
    add("--rep_text_fontsize", default=8, type=int)

    # RF-group coloring.
    add("--color_by_rf_group", action="store_true")
    add("--rf_models_dir", default=None, type=str)
    add("--rf_min_p", default=0.0, type=float)
    add("--rf_legend", action="store_true")
    add("--rf_batch", default=4096, type=int)

    # Output.
    add("--out", default=None, type=str)
    add("--save_npz", action="store_true")
    add("--no_progress", action="store_true")

    return parser.parse_args()


def main():
    """Embed sampled sequences of every method of one run/seed into 2-D and plot them.

    Workflow:
      1. Discover the methods under ``<root>/peptide/*/<run_id>/<seed>``.
      2. For each method, build k-mer features either from cached samples
         (``--use_cache``) or, when that yields no rows, by sampling the final
         checkpoint (filtered by ``--min_logr``).
      3. Stack all features and reduce them to 2-D (``--reducer`` pca/umap).
      4. Optionally colour points by RF-predicted group (``--color_by_rf_group``).
      5. Optionally write a per-method subplot figure (``--subplots``), a focus
         mosaic (``--focus_method``) and the embedding data (``--save_npz``).

    Raises:
        SystemExit: if no methods are found, if no method produced any sample,
            or if RF colouring is requested but the number of sequences does not
            match the number of embedded points.
    """
    args = parse_args()

    root = Path(args.root)
    seed_dirname = parse_seed_arg(args.seed)
    methods = list_methods(root, args.run_id, seed_dirname)
    if not methods:
        raise SystemExit(f"No methods found under {root}/peptide/*/{args.run_id}/{seed_dirname}")

    progress = not args.no_progress
    include_unigram = not args.no_unigram
    include_length = not args.no_length

    rf_bundle: Optional[RFBundle] = None
    if args.color_by_rf_group:
        # Precedence: --rf_models_dir, then $RF_MODELS_DIR, then a default
        # relative to this file.
        env_dir = os.environ.get("RF_MODELS_DIR")
        if args.rf_models_dir is not None:
            models_dir = Path(args.rf_models_dir)
        elif env_dir:
            models_dir = Path(env_dir)
        else:
            # NOTE(review): parents[0] is the directory containing this file,
            # not necessarily the repository root — confirm this is intended.
            here = Path(__file__).resolve()
            models_dir = here.parents[0] / "rf_models" / "models"
        rf_bundle = RFBundle(models_dir=models_dir)

    all_X: List[np.ndarray] = []
    all_labels: List[str] = []
    all_seqs: List[Tuple[int, ...]] = []
    per_method_counts: Dict[str, int] = {}

    method_iter = methods
    if progress and tqdm is not None:
        method_iter = tqdm(methods, desc="methods", leave=True)

    for method in method_iter:
        seed_dir = root / "peptide" / method / args.run_id / seed_dirname
        cfg = load_cfg(seed_dir)

        X_m: np.ndarray
        seqs_m: List[Tuple[int, ...]]

        if args.use_cache:
            X_m, seqs_m = features_from_cached_samples(
                seed_dir,
                cfg=cfg,
                mode=args.cache_mode,
                cache_n_samples=int(args.cache_n_samples),
                cache_thr=float(args.cache_thr),
                cache_tag=str(args.cache_tag),
                k=int(args.k),
                include_unigram=include_unigram,
                include_length=include_length,
                vocab_size=int(args.vocab_size),
            )
        else:
            X_m, seqs_m = np.zeros((0, 1), np.float32), []

        # Fall back to sampling the final checkpoint when the cache is unused or empty.
        if X_m.shape[0] == 0:
            epochs = int(cfg["epochs"])
            ckpt_path = find_final_ckpt(seed_dir, epochs=epochs)
            X_m, seqs_m = collect_unique_good_sequences_by_sampling_final_ckpt(
                seed_dir,
                ckpt_path,
                n_samples=int(args.n_samples),
                batch_size=int(args.batch_size),
                min_logr=float(args.min_logr),
                device=str(args.device),
                k=int(args.k),
                include_unigram=include_unigram,
                include_length=include_length,
                vocab_size=int(args.vocab_size),
                progress=progress,
            )

        per_method_counts[method] = int(X_m.shape[0])
        if X_m.shape[0] == 0:
            continue

        all_X.append(X_m)
        all_labels.extend([method] * X_m.shape[0])
        all_seqs.extend(seqs_m)

    if not all_X:
        raise SystemExit("No samples found for any method (cache empty and fallback sampling produced none).")

    X = np.concatenate(all_X, axis=0).astype(np.float32, copy=False)
    labels = np.array(all_labels, dtype=str)

    Z = reduce_to_2d(
        X,
        reducer=args.reducer,
        pca_pre=int(args.pca_pre),
        seed=int(args.reducer_seed),
        umap_n_neighbors=int(args.umap_n_neighbors),
        umap_min_dist=float(args.umap_min_dist),
    )

    # Keep the discovery order of `methods`. A method has labels iff it
    # contributed at least one row, so the counts answer this without
    # rebuilding a set of all N labels once per method.
    uniq_methods = [m for m in methods if per_method_counts.get(m, 0) > 0]

    group_idx_all: Optional[np.ndarray] = None
    if rf_bundle is not None:
        N = len(all_seqs)
        # group_idx_all is indexed with the same row indices as Z below, so the
        # two must be aligned row-for-row.
        if N != Z.shape[0]:
            raise SystemExit(
                f"Cannot colour by RF group: got {N} sequences for {Z.shape[0]} embedded points."
            )
        maxL = max(1, max((len(s) for s in all_seqs), default=1))
        # Right-pad every sequence with id 0 into an (N, maxL) LongTensor.
        # NOTE(review): assumes 0 is the padding id RFBundle expects — confirm.
        batch_ids = torch.zeros((N, maxL), dtype=torch.long)
        for i, s in enumerate(all_seqs):
            if s:
                batch_ids[i, :len(s)] = torch.tensor(s, dtype=torch.long)
        group_idx_all = rf_bundle.argmax_group(
            batch_ids,
            min_p=float(args.rf_min_p),
            batch_size=int(args.rf_batch),
        )

    if args.subplots:
        n = len(uniq_methods)
        # squeeze=False always returns a 2-D array of Axes, so n == 1 needs no
        # special case.
        fig, axes_grid = plt.subplots(
            1, n, figsize=(4.2 * n, 4.2), sharex=True, sharey=True, squeeze=False
        )
        axes = axes_grid[0]

        xmin, xmax, ymin, ymax = robust_axis_limits(Z, quantile_clip=float(args.xylim_quantile))

        for ax, method in zip(axes, uniq_methods):
            idx = np.where(labels == method)[0]
            if idx.size == 0:
                ax.set_title(f"{method} (0)")
                ax.set_xlim(xmin, xmax)
                ax.set_ylim(ymin, ymax)
                continue

            # Plot-only subsampling; fixed seed keeps figures reproducible.
            if args.max_points and idx.size > int(args.max_points):
                rng = np.random.default_rng(0)
                idx = rng.choice(idx, size=int(args.max_points), replace=False)

            if group_idx_all is None:
                ax.scatter(
                    Z[idx, 0], Z[idx, 1],
                    s=float(args.s),
                    alpha=float(args.alpha),
                    rasterized=True,
                    zorder=1,
                )
            else:
                scatter_by_group_with_inverse_count_zorder(
                    ax,
                    Z[idx],
                    group_idx_all[idx],
                    point_size=float(args.s),
                    point_alpha=float(args.alpha),
                    colors=RF_GROUP_COLORS,
                )

            method_title = field_to_name.get(method, method)
            ax.set_title(rf"{method_title} [{per_method_counts.get(method, 0)}]")
            ax.set_xlim(xmin, xmax)
            ax.set_ylim(ymin, ymax)
            apply_format(ax)

        fig.tight_layout()

        out_dir = root / "peptide" / "plots"
        out_dir.mkdir(parents=True, exist_ok=True)
        mode = f"cache_{args.cache_mode}" if args.use_cache else "sample_final"
        tag = "rfgroup" if args.color_by_rf_group else "plain"
        out_path = out_dir / f"embed_subplots_k{args.k}_{args.reducer}_{mode}_{tag}_{args.run_id}_{seed_dirname}.png"
        fig.savefig(out_path, dpi=220)
        plt.close(fig)
        print(f"Saved comparative plot: {out_path}")

        # The RF-group legend only describes the figure when points were
        # actually coloured by RF group.
        if group_idx_all is not None:
            handles = [
                Line2D([0], [0], marker="o", linestyle="None", markersize=5,
                       color=RF_GROUP_COLORS[i], label=RF_GROUP_NAMES[i])
                for i in range(len(RF_GROUP_NAMES))
            ]
            leg_path = out_path.with_name(out_path.stem + "_legend.png")
            save_legend(handles, RF_GROUP_NAMES, leg_path, ncol=5, fontsize=9)
            print(f"Saved legend: {leg_path}")

    if args.focus_method is not None:
        if args.focus_out is None:
            out_dir2 = root / "peptide" / "plots"
            out_dir2.mkdir(parents=True, exist_ok=True)
            out2 = out_dir2 / f"focus_mosaic_{args.focus_method}_{args.run_id}_{seed_dirname}_{args.reducer}.png"
        else:
            out2 = Path(args.focus_out)

        make_focus_mosaic_figure(
            Z=Z,
            labels=labels,
            seqs=all_seqs,
            focus_method=str(args.focus_method),
            out_path=out2,
            xylim_quantile=float(args.xylim_quantile),
            rep_k=int(args.focus_rep_k),
            rep_selector=str(args.focus_rep_selector),
            reducer_name=str(args.reducer),
            run_id=str(args.run_id),
            seed_dirname=str(seed_dirname),
            img_size=int(args.rep_img_size),
            fallback_text=bool(args.rep_fallback_text),
            text_trunc=int(args.rep_text_trunc),
            text_fontsize=int(args.rep_text_fontsize),
            reducer_seed=int(args.reducer_seed),
            vocab_size=int(args.vocab_size),
            max_points=int(args.max_points),
            point_alpha=float(args.alpha),
            point_size=float(args.s),
            group_idx_all=group_idx_all,
            group_colors=RF_GROUP_COLORS if group_idx_all is not None else None,
        )

    if args.save_npz:
        out_dir = root / "peptide" / "plots"
        out_dir.mkdir(parents=True, exist_ok=True)
        npz_path = out_dir / f"embed_data_k{args.k}_{args.reducer}_{args.run_id}_{seed_dirname}.npz"
        # Build a 1-D object array of tuples explicitly. np.array(list_of_tuples,
        # dtype=object) silently becomes 2-D (N, L) when all sequences share a
        # length.
        seqs_arr = np.empty(len(all_seqs), dtype=object)
        for i, s in enumerate(all_seqs):
            seqs_arr[i] = s
        np.savez_compressed(
            npz_path,
            Z=Z.astype(np.float32),
            labels=labels.astype(str),
            X=X.astype(np.float32),
            seqs=seqs_arr,
            # NOTE: None is stored as a 0-d object array. Loading object arrays
            # (this key and `seqs`) requires np.load(..., allow_pickle=True).
            group_idx=group_idx_all.astype(np.int32) if group_idx_all is not None else None,
        )
        print(f"Saved data: {npz_path}")

    print("\nUnique good sequences per method:")
    for m in methods:
        print(f"  - {m:20s}: {per_method_counts.get(m, 0)}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script. The module uses relative
    # imports, so run it as a module (python -m ...), not by file path.
    main()