import numpy as np, torch
import math
import os
from pathlib import Path
import joblib


# Optional override: directory that holds the RF model bundles.
_env_dir = os.environ.get("RF_MODELS_DIR")


# Directory of this source file. NOTE(review): an earlier comment said this
# climbs one level up from simple_pep/, but `.parent` (== `.parents[0]`) is the
# file's own directory, not its parent — confirm the intended repository layout.
_HERE = Path(__file__).resolve()
_REPO_ROOT = _HERE.parent
_DEFAULT_MODELS_DIR = _REPO_ROOT.joinpath("rf_models", "models")

# A non-empty RF_MODELS_DIR wins; otherwise fall back to the default location.
if _env_dir:
    MODELS_DIR = Path(_env_dir)
else:
    MODELS_DIR = _DEFAULT_MODELS_DIR

# One RF bundle file per target organism, all expected inside MODELS_DIR.
_BUNDLE_FILES = [
    "rf_ecoli.joblib",
    "rf_saureus.joblib",
    "rf_paeruginosa.joblib",
    "rf_bsubtilis.joblib",
    "rf_calbicans.joblib",
]
_BUNDLE_PATHS = [MODELS_DIR.joinpath(fname) for fname in _BUNDLE_FILES]

# Fail fast at import time, reporting every missing file in a single error.
_missing = [str(path) for path in _BUNDLE_PATHS if not path.exists()]
if _missing:
    raise FileNotFoundError(
        "Não encontrei os modelos RF.\n"
        f"  MODELS_DIR = {MODELS_DIR}\n"
        f"  Faltando: {_missing}\n\n"
        "Defina RF_MODELS_DIR (caminho ABSOLUTO para 'rf_models/models') "
        "ou ajuste _REPO_ROOT conforme sua árvore de diretórios."
    )

# Loaded in the same order as _BUNDLE_FILES. NOTE(review): joblib.load unpickles
# the file, so MODELS_DIR (possibly set via RF_MODELS_DIR) must be trusted.
BUNDLES = [joblib.load(str(path)) for path in _BUNDLE_PATHS]

# All bundles must share one tokenization (same vocab list and max length),
# because rf_proba feeds a single one-hot encoding to every model.
VOCAB = BUNDLES[0]["vocab"]
MAX_LEN = BUNDLES[0]["max_len"]
# Explicit raises instead of `assert`: these validate loaded model files and must
# still run under `python -O`, which strips assert statements.
if not all(b["vocab"] == VOCAB and b["max_len"] == MAX_LEN for b in BUNDLES):
    raise ValueError("Models disagree on vocab/max_len.")

# Token -> vocabulary index.
TOK = {t: i for i, t in enumerate(VOCAB)}

if "*" not in TOK:
    raise ValueError("Vocab must contain '*' (EOS).")
# A bundle may record its EOS index explicitly; otherwise use the position of '*'.
EOS_IDX = BUNDLES[0].get("eos_index", TOK["*"])

# EOS must be index 0: _encode_onehot_np pads with EOS_IDX, while token counting
# in LogReward.forward treats id 0 as padding (`seq_batch != 0`).
if EOS_IDX != 0:
    raise ValueError(f"Expected EOS at index 0, got {EOS_IDX}. Re-train models or check vocab.")

# Vocabulary size = one-hot width of each sequence position.
V = len(VOCAB)

def _encode_onehot_np(batch_ids: torch.Tensor, lengths: torch.Tensor = None) -> np.ndarray:
    """One-hot encode a batch of token ids into flat RF feature vectors.

    Args:
        batch_ids: 2-D tensor [B, Lmax] of token ids. Columns beyond MAX_LEN
            are dropped.
        lengths: optional per-row valid lengths (assumed shape [B] — TODO
            confirm with callers), clipped to the kept width. When omitted,
            every kept column is treated as valid.

    Returns:
        float32 array of shape [B, MAX_LEN * V]. Each position is one-hot over
        the vocabulary. Positions past a row's length, positions beyond the
        input width, and ids outside [0, V) are all encoded as EOS (padding).

    Raises:
        ValueError: if batch_ids is not 2-D.
    """
    if batch_ids.ndim != 2:
        raise ValueError(f"batch_ids deve ter shape [B, Lmax], recebido {tuple(batch_ids.shape)}")
    B = batch_ids.shape[0]
    ids = batch_ids[:, :MAX_LEN].to(torch.long).detach().cpu().numpy()
    L = ids.shape[1]

    # Start with every position set to EOS, which doubles as padding.
    X = np.zeros((B, MAX_LEN, V), dtype=np.float32)
    X[:, :, EOS_IDX] = 1.0

    if lengths is None:
        eff_L = np.full((B,), L, dtype=np.int64)
    else:
        eff_L = np.minimum(lengths.to(torch.long).detach().cpu().numpy(), L)

    # Positions inside each row's effective length that hold an in-vocab id.
    # Non-positive lengths select nothing, so such rows stay all-EOS.
    in_length = np.arange(L)[None, :] < eff_L[:, None]
    valid = in_length & (ids >= 0) & (ids < V)
    b_idx, pos_idx = np.nonzero(valid)
    X[b_idx, pos_idx, :] = 0.0
    X[b_idx, pos_idx, ids[b_idx, pos_idx]] = 1.0

    # Explicit width: reshape(B, -1) cannot infer -1 when B == 0.
    return X.reshape(B, MAX_LEN * V)


def rf_proba(batch_ids):
    """Return, per sequence, the highest class-1 probability across all RF bundles.

    Args:
        batch_ids: 2-D tensor [B, Lmax] of token ids (see _encode_onehot_np).

    Returns:
        numpy array of shape [B]: max over bundles of predict_proba(...)[:, 1].
    """
    features = _encode_onehot_np(batch_ids)  # [B, MAX_LEN * V], float32
    per_model = np.stack(
        [bundle["model"].predict_proba(features)[:, 1] for bundle in BUNDLES],
        axis=1,
    )  # [B, n_bundles]
    return np.max(per_model, axis=1)


class LogReward(torch.nn.Module):
    """Clipped log-reward computed from the RF ensemble probability (rf_proba).

    The reward is based on the logit gap between the ensemble probability and
    `cutoff`, divided by the temperature `t`. Negative values are multiplied by
    the sequence size, and the result is clipped to [-30, 0].
    """

    def __init__(self, cutoff=0.95, t=0.3, device: str = "cpu", **kwargs):
        super().__init__()
        self.cutoff = cutoff  # probability at/above which logR is clipped to 0
        self.T = t  # temperature dividing the logit gap
        # NOTE(review): stored but not read by forward(); outputs follow seq_batch's device.
        self.device = torch.device(device)

    def forward(self, seq_batch):
        """Return clipped log-rewards for a batch of token-id sequences.

        Args:
            seq_batch: 2-D integer tensor [B, L] of token ids, where id 0 is
                EOS/padding. This is not a list of strings: rf_proba requires
                a 2-D tensor.

        Returns:
            float32 tensor [B] with values in [-30, 0], on seq_batch's device.
        """
        # Create p on seq_batch's device. Otherwise `logR * sizes` below mixes a
        # CPU tensor with the (possibly CUDA) `sizes` tensor and raises.
        p = torch.tensor(rf_proba(seq_batch), dtype=torch.float32, device=seq_batch.device)
        # Count of non-zero (non-EOS) tokens plus one.
        sizes = (seq_batch != 0).sum(1) + 1
        # p == 0 or p == 1 gives -inf/+inf here; the final clip bounds both.
        logit_p = p.log() - (1 - p).log()
        logit_cutoff = math.log(self.cutoff) - math.log(1 - self.cutoff)
        logR = (logit_p - logit_cutoff) / self.T
        # Below-cutoff penalties grow with sequence size.
        logR = torch.where(logR < 0, logR * sizes, logR)
        return logR.clip(-30, 0)


if __name__ == "__main__":
    # Smoke test: print the ensemble probability for a few hand-written id batches.
    # rf_proba expects a 2-D tensor of token ids, not raw peptide strings.
    for ids in ([[1, 2]], [[3, 5]], [[1, 2], [3, 5]]):
        print(rf_proba(torch.tensor(ids)))