"""Optional learned verifier baseline."""

from __future__ import annotations

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

from .utils import CATEGORY_NAMES, unit_compatible


def run_learned_baseline(problems: pd.DataFrame, candidates: pd.DataFrame, seed: int) -> pd.DataFrame:
    """Fit a random forest verifier and score it on held-out problems.

    Every candidate row is joined to the ground-truth value and unit of its
    problem, turned into features by ``_make_features``, and labelled with
    its ``is_valid`` flag. The split is made over distinct ``problem_id``
    values (30% held out), so all candidates of a given problem land on the
    same side of the split.

    Args:
        problems: Frame with ``problem_id``, ``ground_truth_value`` and
            ``ground_truth_unit`` columns; at most one row per problem
            (enforced by the ``many_to_one`` merge validation).
        candidates: Frame of candidate answers carrying ``problem_id``,
            ``candidate_id``, ``category``, ``candidate_type`` and
            ``is_valid``, plus whatever ``_make_features`` reads.
        seed: Random state used for both the split and the forest.

    Returns:
        One row per held-out candidate with the columns ``problem_id``,
        ``candidate_id``, ``category``, ``candidate_type``, ``is_valid``,
        ``method``, ``split``, ``prediction`` and ``supported``. An empty,
        column-less frame is returned when fewer than four distinct
        problems are available.
    """

    truth_columns = ["problem_id", "ground_truth_value", "ground_truth_unit"]
    joined = candidates.merge(problems[truth_columns], on="problem_id", how="left", validate="many_to_one")

    design = _make_features(joined)
    targets = joined["is_valid"].astype(bool).to_numpy()

    unique_ids = np.array(sorted(joined["problem_id"].unique()))
    if len(unique_ids) < 4:
        # Too few problems to carve out a meaningful held-out set.
        return pd.DataFrame()

    # Split the problem ids rather than the rows to avoid leaking a problem
    # across the train/test boundary.
    fit_ids, held_out_ids = train_test_split(unique_ids, test_size=0.30, random_state=seed, shuffle=True)
    in_train = joined["problem_id"].isin(fit_ids).to_numpy()
    in_test = joined["problem_id"].isin(held_out_ids).to_numpy()

    forest_params = {
        "n_estimators": 200,
        "max_depth": 8,
        "min_samples_leaf": 2,
        "class_weight": "balanced",
        "random_state": seed,
        "n_jobs": 1,
    }
    forest = RandomForestClassifier(**forest_params)
    forest.fit(design[in_train], targets[in_train])
    predicted = forest.predict(design[in_test])

    held_out = joined.loc[in_test].reset_index(drop=True)
    report = held_out[["problem_id", "candidate_id", "category", "candidate_type"]].copy()
    report["is_valid"] = held_out["is_valid"].astype(bool)
    report["method"] = "learned_baseline"
    report["split"] = "test"
    report["prediction"] = predicted.astype(bool)
    report["supported"] = True
    return report


def _make_features(merged: pd.DataFrame) -> pd.DataFrame:
    """Build the all-float feature matrix for the learned baseline.

    Args:
        merged: Candidate rows already joined with their problem's ground
            truth. Must provide ``candidate_value``, ``candidate_unit``,
            ``ground_truth_value``, ``ground_truth_unit`` and ``category``.

    Returns:
        A frame with a fresh default index, one row per row of ``merged``,
        holding:

        * ``absolute_relative_numeric_error`` - ``|candidate - truth|``
          divided by ``max(|truth|, 1)`` and capped at ``1e6``. Undefined
          errors (NaN candidate or truth, ``inf - inf``) are also mapped to
          the cap, so the column never contains NaN or infinity.
        * ``unit_compatible`` - result of ``unit_compatible`` per row.
        * ``candidate_unit_equals_ground_truth_unit`` - string equality of
          the two unit columns.
        * ``candidate_value_finite`` - whether the candidate value is finite.
        * ``category=<name>`` - one indicator column per ``CATEGORY_NAMES``
          entry.

        Every column is cast to ``float``.
    """
    max_relative_error = 1e6

    # Convert once; the arrays are reused by several features below.
    candidate_values = merged["candidate_value"].astype(float).to_numpy()
    truth_values = merged["ground_truth_value"].astype(float).to_numpy()

    # Floor the scale at 1 so near-zero truths do not blow up the ratio.
    truth_scale = np.maximum(np.abs(truth_values), 1.0)
    with np.errstate(invalid="ignore"):
        raw_error = np.abs(candidate_values - truth_values)
        # np.fmin (unlike np.minimum) ignores NaN, so an undefined error is
        # clamped to the cap exactly like an infinite one instead of leaking
        # NaN into the estimator input.
        relative_error = np.fmin(raw_error / truth_scale, max_relative_error)

    feature_frame = pd.DataFrame(
        {
            "absolute_relative_numeric_error": relative_error,
            "unit_compatible": [
                unit_compatible(candidate, truth)
                for candidate, truth in zip(merged["candidate_unit"], merged["ground_truth_unit"])
            ],
            "candidate_unit_equals_ground_truth_unit": merged["candidate_unit"].astype(str).to_numpy()
            == merged["ground_truth_unit"].astype(str).to_numpy(),
            "candidate_value_finite": np.isfinite(candidate_values),
        }
    )
    for category in CATEGORY_NAMES:
        feature_frame[f"category={category}"] = merged["category"].eq(category).to_numpy()
    return feature_frame.astype(float)
