"""Features for the span classifier: TF-IDF + hand-crafted scalars."""
from __future__ import annotations

import math
from dataclasses import dataclass

import numpy as np
import scipy.sparse as sp
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import StandardScaler


# Separator inserted between the preceding context and the span text when
# text_for_row builds the string that is fed to the TF-IDF vectorizer.
SEP = " <SEP> "
# Names of the hand-crafted scalar feature columns, in the same order as the
# list returned by hand_crafted_row. All of these columns are standardized by
# the StandardScaler fitted in fit_feature_pipeline.
HAND_CRAFTED_NAMES = [
    "n_tokens_z",  # raw n_tokens, z-scored later by the StandardScaler
    "log_n_tokens",  # log1p(n_tokens)
    "position_in_trace",  # episode_idx / (n_episodes_in_trace - 1); 0.0 if one episode
    "is_first_span",  # 1.0 when episode_idx == 0
    "is_last_span",  # 1.0 when episode_idx == n_episodes_in_trace - 1
    "math_density",  # fraction of digit/operator characters in span_text
    "correct_int",  # truthiness of row["correct"] as 0.0 / 1.0
]


def text_for_row(row: dict) -> str:
    """Build the single string that is vectorized for one row.

    The row's ``preceding_context`` and ``span_text`` values are each stripped
    of surrounding whitespace and joined with ``SEP``. A value that is
    missing or falsy (e.g. ``None``) is treated as the empty string.
    """
    stripped = [
        (row.get(field) or "").strip()
        for field in ("preceding_context", "span_text")
    ]
    return SEP.join(stripped)


def math_density(text: str) -> float:
    """Return the fraction of characters in *text* that are digits or ``+-*/=()``.

    An empty (or otherwise falsy) *text* yields ``0.0``.
    """
    if not text:
        return 0.0
    symbols = "0123456789+-*/=()"
    # Summing booleans counts the matching characters.
    hits = sum(ch in symbols for ch in text)
    # text is non-empty here, so len(text) >= 1.
    return hits / len(text)


def hand_crafted_row(row: dict) -> list[float]:
    """Compute the hand-crafted scalar features for one row.

    The returned values are ordered to match ``HAND_CRAFTED_NAMES``:

    1. ``n_tokens`` as a raw float; the StandardScaler standardizes it later.
    2. ``log1p(n_tokens)``.
    3. The relative position ``episode_idx / (n_episodes_in_trace - 1)``, or
       ``0.0`` when the trace has a single episode.
    4. ``1.0`` if this is episode 0, else ``0.0``.
    5. ``1.0`` if this is the last episode, else ``0.0``. In a one-episode
       trace, index 0 is both first and last.
    6. ``math_density`` of ``span_text``.
    7. Truthiness of ``correct`` as ``0.0`` / ``1.0``.

    Missing or falsy ``n_tokens`` and ``n_episodes_in_trace`` default to 1, and
    ``episode_idx`` defaults to 0. Because the defaulting uses ``or``, an
    explicit ``n_tokens`` of 0 is also mapped to 1.
    """
    token_count = row.get("n_tokens") or 1
    episode_count = row.get("n_episodes_in_trace") or 1
    episode = row.get("episode_idx") or 0

    if episode_count > 1:
        position = episode / max(1, episode_count - 1)
    else:
        position = 0.0

    first_flag = episode == 0
    last_flag = episode == episode_count - 1
    density = math_density(row.get("span_text") or "")
    correct_flag = bool(row.get("correct"))

    return [
        float(token_count),
        math.log1p(token_count),
        float(position),
        float(first_flag),
        float(last_flag),
        density,
        float(correct_flag),
    ]


@dataclass
class FeaturePipeline:
    """Bundle of the fitted vectorizer and scaler, pickled with the model.

    Instances are returned by ``fit_feature_pipeline``. Calling ``transform``
    on new rows rebuilds the same column layout that was produced for the
    training rows.
    """

    vectorizer: TfidfVectorizer  # fitted on text_for_row(...) strings
    scaler: StandardScaler  # fitted on hand_crafted_row(...) vectors

    def transform(self, rows: list[dict]) -> sp.csr_matrix:
        """Featurize *rows* with the already-fitted vectorizer and scaler.

        Returns a CSR matrix with one row per input row. The TF-IDF columns
        come first, followed by the standardized hand-crafted columns.

        An empty *rows* returns an all-empty ``(0, n_features)`` matrix.
        Without this guard, sklearn estimators reject zero-sample input
        with a ``ValueError``.
        """
        if not rows:
            # len(vocabulary_) equals the TF-IDF column count (max_features
            # pruning is already reflected in it). n_features_in_ is the
            # number of hand-crafted columns the scaler was fitted on.
            n_cols = len(self.vectorizer.vocabulary_) + self.scaler.n_features_in_
            # float64 matches the non-empty path: the default-dtype TF-IDF
            # block (float64) stacked with float32 upcasts to float64.
            return sp.csr_matrix((0, n_cols), dtype=np.float64)

        texts = [text_for_row(r) for r in rows]
        X_text = self.vectorizer.transform(texts)
        X_hand = np.array([hand_crafted_row(r) for r in rows], dtype=np.float32)
        X_hand_scaled = self.scaler.transform(X_hand)
        return sp.hstack([X_text, sp.csr_matrix(X_hand_scaled)], format="csr")


def fit_feature_pipeline(
    rows: list[dict],
    ngram_range: tuple[int, int] = (1, 3),
    max_features: int = 50_000,
    min_df: int = 2,
) -> tuple[FeaturePipeline, sp.csr_matrix]:
    """Fit the TF-IDF vectorizer and the hand-crafted-feature scaler on *rows*.

    Args:
        rows: Training rows. Each is passed to ``text_for_row`` and
            ``hand_crafted_row``.
        ngram_range: Word n-gram range for the TF-IDF vocabulary.
        max_features: Cap on the TF-IDF vocabulary size.
        min_df: Minimum document frequency for a term to be kept.

    Returns:
        A ``(pipeline, X)`` pair. ``pipeline`` wraps the fitted vectorizer and
        scaler. ``X`` is the CSR training matrix: TF-IDF columns followed by
        the standardized hand-crafted columns.
    """
    vectorizer = TfidfVectorizer(
        lowercase=True,
        ngram_range=ngram_range,
        min_df=min_df,
        max_features=max_features,
        sublinear_tf=True,
        norm="l2",
    )
    text_block = vectorizer.fit_transform([text_for_row(row) for row in rows])

    scaler = StandardScaler()
    raw_scalars = np.array([hand_crafted_row(row) for row in rows], dtype=np.float32)
    scalar_block = scaler.fit_transform(raw_scalars)

    features = sp.hstack([text_block, sp.csr_matrix(scalar_block)], format="csr")
    pipeline = FeaturePipeline(vectorizer=vectorizer, scaler=scaler)
    return pipeline, features
