"""Train option B-hybrid: TF-IDF (1-3 grams) + BGE-small embedding + LogReg.

Concatenates the v1 TF-IDF feature space with the v2 384-dim BGE embedding,
plus the same hand-crafted scalars. Logistic regression head.
"""
from __future__ import annotations

import argparse
import json
import time
from dataclasses import dataclass
from pathlib import Path

import joblib
import numpy as np
import scipy.sparse as sp
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.preprocessing import StandardScaler

from analysis.exploration.llm_validation._client import PRIMITIVES
from analysis.exploration.llm_validation.classifier.embed_features import (
    DEFAULT_ENCODER,
    encode_rows,
    _get_encoder,
)
from analysis.exploration.llm_validation.classifier.features import (
    hand_crafted_row,
    text_for_row,
)
from analysis.exploration.llm_validation.classifier.prepare_dataset import (
    load_silver,
    report_split,
    split_by_trace,
)


@dataclass
class HybridFeaturePipeline:
    """Inference-time featurizer bundled next to the trained classifier.

    ``transform`` rebuilds the feature matrix with the same column layout that
    ``main`` uses for training: TF-IDF columns first, then the dense embedding,
    then the scaled hand-crafted scalars.
    """

    # Name passed to ``_get_encoder`` to obtain the embedding model.
    encoder_name: str
    # Vectorizer whose ``transform`` is applied to ``text_for_row`` output.
    vectorizer: TfidfVectorizer
    # Scaler whose ``transform`` is applied to the hand-crafted feature rows.
    scaler: StandardScaler

    def transform(self, rows: list[dict]):
        """Featurize ``rows`` and return the stacked matrix in CSR format.

        Args:
            rows: Span records accepted by ``encode_rows``, ``text_for_row``
                and ``hand_crafted_row``.

        Returns:
            A sparse CSR matrix with one row per input record.
        """
        # The encoder is looked up by name on every call; it is not stored on
        # the instance.
        embeddings = encode_rows(
            rows, encoder=_get_encoder(self.encoder_name), show_progress=False,
        )
        tfidf = self.vectorizer.transform([text_for_row(row) for row in rows])
        scalars = self.scaler.transform(
            np.array([hand_crafted_row(row) for row in rows], dtype=np.float32)
        )
        # Dense blocks are wrapped as CSR so all three pieces stack sparsely.
        blocks = [tfidf, sp.csr_matrix(embeddings), sp.csr_matrix(scalars)]
        return sp.hstack(blocks, format="csr")


def labels_to_y(rows, label_order):
    """Encode each row's ``"llm_label"`` as its position in ``label_order``.

    Args:
        rows: Iterable of mappings that each contain an ``"llm_label"`` key.
        label_order: Sequence of label names; a label's index in this
            sequence is its class id.

    Returns:
        A 1-D ``int32`` NumPy array with one class id per row.

    Raises:
        KeyError: If a row lacks ``"llm_label"`` or its label is not in
            ``label_order``.
    """
    position_of = {}
    for position, label in enumerate(label_order):
        position_of[label] = position
    return np.fromiter(
        (position_of[row["llm_label"]] for row in rows),
        dtype=np.int32,
    )


def _hand_matrix(rows):
    """Return the hand-crafted features of ``rows`` as a float32 array, one row each."""
    return np.array([hand_crafted_row(r) for r in rows], dtype=np.float32)


def _stack_features(text_block, emb_block, hand_block):
    """Join TF-IDF, embedding and scaled hand-crafted columns (in that order) as CSR."""
    return sp.hstack(
        [text_block, sp.csr_matrix(emb_block), sp.csr_matrix(hand_block)],
        format="csr",
    )


def main():
    """Train the hybrid classifier from the command line and save a joblib bundle.

    Command-line arguments:
        --silver: Path handed to ``load_silver`` to read the labelled rows.
        --out: Destination of the joblib bundle; parent directories are created.
        --reuse-splits: JSON file with ``train_ids``, ``val_ids`` and
            ``test_ids`` lists of span ids. It is copied next to the bundle
            with a ``.splits.json`` suffix.
        --encoder: Encoder name passed to ``_get_encoder``
            (default ``DEFAULT_ENCODER``).

    The regularisation strength ``C`` is chosen by macro-F1 on the validation
    split; the selected model is then scored once on the test split.
    """
    import shutil

    ap = argparse.ArgumentParser()
    ap.add_argument("--silver", required=True, type=Path)
    ap.add_argument("--out", required=True, type=Path)
    ap.add_argument("--reuse-splits", type=Path, required=True)
    ap.add_argument("--encoder", default=DEFAULT_ENCODER)
    args = ap.parse_args()

    rows = load_silver(args.silver)
    # Read the split file inside a context manager so the handle is closed
    # promptly; JSON is read as UTF-8 regardless of the platform default.
    with open(args.reuse_splits, encoding="utf-8") as fh:
        splits = json.load(fh)
    by_id = {r["span_id"]: r for r in rows}
    # Ids listed in the split file but absent from the silver rows are skipped.
    train = [by_id[i] for i in splits["train_ids"] if i in by_id]
    val = [by_id[i] for i in splits["val_ids"] if i in by_id]
    test = [by_id[i] for i in splits["test_ids"] if i in by_id]
    report_split("train", train)
    report_split("val", val)
    report_split("test", test)

    label_order = list(PRIMITIVES)
    y_train = labels_to_y(train, label_order)
    y_val = labels_to_y(val, label_order)
    y_test = labels_to_y(test, label_order)

    # ---- Encode (CUDA, per the original author's note) ----
    print()
    print(f"Encoding with {args.encoder}...")
    encoder = _get_encoder(args.encoder)
    emb_train = encode_rows(train, encoder=encoder, batch_size=128)
    emb_val = encode_rows(val, encoder=encoder, batch_size=128, show_progress=False)
    emb_test = encode_rows(test, encoder=encoder, batch_size=128, show_progress=False)
    print(f"  shapes: train={emb_train.shape}, val={emb_val.shape}, test={emb_test.shape}")

    # ---- Hand-crafted scaled ----
    # The scaler is fitted on the training split only, then applied to val/test.
    scaler = StandardScaler()
    hand_train_s = scaler.fit_transform(_hand_matrix(train))
    hand_val_s = scaler.transform(_hand_matrix(val))
    hand_test_s = scaler.transform(_hand_matrix(test))

    # ---- TF-IDF ----
    # The vocabulary is likewise fitted on the training split only.
    vec = TfidfVectorizer(
        ngram_range=(1, 3), max_features=50_000, sublinear_tf=True,
        lowercase=True, min_df=2, norm="l2",
    )
    X_train_text = vec.fit_transform([text_for_row(r) for r in train])
    X_val_text = vec.transform([text_for_row(r) for r in val])
    X_test_text = vec.transform([text_for_row(r) for r in test])

    # Column order must match HybridFeaturePipeline.transform.
    X_train = _stack_features(X_train_text, emb_train, hand_train_s)
    X_val = _stack_features(X_val_text, emb_val, hand_val_s)
    X_test = _stack_features(X_test_text, emb_test, hand_test_s)
    print(f"  hybrid X_train shape: {X_train.shape}")

    # ---- Grid search ----
    print()
    print("=== Grid search on val ===")
    best = None
    for C in [0.1, 1.0, 10.0, 100.0]:
        t0 = time.time()
        clf = LogisticRegression(
            class_weight="balanced", max_iter=2000, C=C, n_jobs=-1, solver="lbfgs",
        )
        clf.fit(X_train, y_train)
        macro = f1_score(y_val, clf.predict(X_val), average="macro", zero_division=0)
        print(f"  C={C:>6}: val_macro_f1={macro:.4f}  ({time.time()-t0:.1f}s)")
        # Strict '>' keeps the smallest C when validation scores tie.
        if best is None or macro > best["macro"]:
            best = {"clf": clf, "C": C, "macro": macro}

    # The test split is scored once, with the model selected on validation.
    test_macro = f1_score(y_test, best["clf"].predict(X_test), average="macro", zero_division=0)
    print()
    print(f"Best: C={best['C']} val_macro_f1={best['macro']:.4f} test_macro_f1={test_macro:.4f}")

    # NOTE(review): the pipeline object is pickled by reference to its class.
    # If this file is executed as a script the class is presumably recorded
    # under ``__main__`` -- confirm that bundle loaders can resolve it.
    pipe = HybridFeaturePipeline(
        encoder_name=args.encoder, vectorizer=vec, scaler=scaler,
    )
    bundle = {
        "feature_pipeline": pipe,
        "classifier": best["clf"],
        "label_order": label_order,
        "hyperparams": {"C": best["C"]},
        "pipeline_type": "hybrid",
    }
    args.out.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(bundle, args.out)
    print(f"Saved bundle -> {args.out}")
    # Keep the exact split definition next to the model for reproducibility.
    shutil.copy(args.reuse_splits, args.out.with_suffix(".splits.json"))


# Script entry point: run training only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()
