"""Sentence-embedding feature pipeline (option B).

Encoder: BAAI/bge-small-en-v1.5 (384-dim, supports CUDA).
Features: 384-dim embedding of `preceding_context + " <SEP> " + span_text`
        + the same 7 hand-crafted scalars from the TF-IDF pipeline.

Implements the same .transform(rows) -> 2D numpy array contract as
FeaturePipeline so evaluate.py can use it without branching.
"""
from __future__ import annotations

import os
from dataclasses import dataclass
from pathlib import Path

import numpy as np
from sklearn.preprocessing import StandardScaler

from analysis.exploration.llm_validation.classifier.features import (
    HAND_CRAFTED_NAMES,
    hand_crafted_row,
    text_for_row,
)


DEFAULT_ENCODER = "BAAI/bge-small-en-v1.5"

# Process-wide cache of loaded encoders, keyed by (model name, resolved device).
# Constructing a SentenceTransformer is expensive, and
# EmbedFeaturePipeline.transform() requests an encoder on every call, so
# without this cache the model would be reloaded for each transform.
_ENCODER_CACHE: dict[tuple[str, str], object] = {}


def _get_encoder(name: str = DEFAULT_ENCODER, device: str | None = None):
    """Return a ``SentenceTransformer`` for *name*, constructing it at most once.

    Args:
        name: Model identifier passed to ``SentenceTransformer``.
        device: Device passed to ``SentenceTransformer``. When ``None``,
            ``"cuda"`` is chosen if ``torch.cuda.is_available()`` and
            ``"cpu"`` otherwise.

    Returns:
        The encoder for ``(name, device)``. Repeated calls with the same
        name and resolved device return the same instance.
    """
    if device is None:
        import torch
        device = "cuda" if torch.cuda.is_available() else "cpu"

    cache_key = (name, device)
    encoder = _ENCODER_CACHE.get(cache_key)
    if encoder is None:
        # Imported only on a cache miss so a cache hit never pays for (or
        # requires) the sentence_transformers import.
        from sentence_transformers import SentenceTransformer
        encoder = SentenceTransformer(name, device=device)
        _ENCODER_CACHE[cache_key] = encoder
    return encoder


def encode_rows(
    rows: list[dict],
    encoder=None,
    batch_size: int = 64,
    show_progress: bool = True,
) -> np.ndarray:
    """Embed the text of every row and return the encoder's output.

    Args:
        rows: Rows to embed; each one is turned into a string with
            ``text_for_row``.
        encoder: Object exposing ``encode(texts, ...)``. When ``None``, the
            default encoder from ``_get_encoder()`` is used.
        batch_size: Forwarded to ``encoder.encode`` as ``batch_size``.
        show_progress: Forwarded to ``encoder.encode`` as
            ``show_progress_bar``.

    Returns:
        Whatever ``encoder.encode`` returns when asked for normalized numpy
        embeddings; intended to be a dense (N, embedding_dim) array.
    """
    model = _get_encoder() if encoder is None else encoder
    sentences = [text_for_row(row) for row in rows]
    encode_options = {
        "batch_size": batch_size,
        "show_progress_bar": show_progress,
        "normalize_embeddings": True,
        "convert_to_numpy": True,
    }
    return model.encode(sentences, **encode_options)


@dataclass
class EmbedFeaturePipeline:
    """Feature transformer combining text embeddings with scaled scalars.

    ``transform`` produces one row per input row: the embedding columns
    followed by the hand-crafted features after scaling.
    """

    # Model name handed to _get_encoder() inside transform().
    encoder_name: str
    # Scaler whose .transform() is applied to the hand-crafted features.
    scaler: StandardScaler
    # Embedding width stored with the pipeline; transform() does not read it.
    embedding_dim: int

    def transform(self, rows: list[dict]) -> np.ndarray:
        """Return the float32 feature matrix for *rows*.

        Columns are the embeddings from ``encode_rows`` followed by the
        hand-crafted features passed through ``self.scaler``.
        """
        model = _get_encoder(self.encoder_name)
        embeddings = encode_rows(rows, encoder=model, show_progress=False)
        scalars = np.array(
            [hand_crafted_row(row) for row in rows],
            dtype=np.float32,
        )
        scaled_scalars = self.scaler.transform(scalars)
        features = np.hstack([embeddings, scaled_scalars])
        return features.astype(np.float32)


def fit_embed_pipeline(
    rows: list[dict],
    encoder_name: str = DEFAULT_ENCODER,
) -> tuple[EmbedFeaturePipeline, np.ndarray]:
    """Fit an ``EmbedFeaturePipeline`` on *rows* and build their features.

    Args:
        rows: Training rows; embedded with ``encode_rows`` and converted to
            hand-crafted features with ``hand_crafted_row``.
        encoder_name: Model name used to obtain the encoder and stored on
            the returned pipeline.

    Returns:
        A ``(pipeline, X)`` pair. ``pipeline`` holds the scaler fitted on
        the hand-crafted features of *rows*; ``X`` is the float32 matrix of
        embeddings followed by those scaled features.
    """
    model = _get_encoder(encoder_name)
    embeddings = encode_rows(rows, encoder=model)
    scalars = np.array(
        [hand_crafted_row(row) for row in rows],
        dtype=np.float32,
    )

    fitted_scaler = StandardScaler()
    scaled_scalars = fitted_scaler.fit_transform(scalars)

    pipeline = EmbedFeaturePipeline(
        encoder_name=encoder_name,
        scaler=fitted_scaler,
        embedding_dim=embeddings.shape[1],
    )
    features = np.hstack([embeddings, scaled_scalars]).astype(np.float32)
    return pipeline, features
