"""Verification methods for candidate biomedical math solutions."""

from __future__ import annotations

import json
import math
from typing import Any

import numpy as np
import pandas as pd

from .utils import (
    UREG,
    convert_value,
    make_quantity,
    numeric_close,
    parse_unit,
    unit_compatible,
)


# Names of the rule-based verification methods; these are the same method
# labels that ``run_all_verifiers`` writes into its ``method`` column.
DETERMINISTIC_METHODS = [
    "answer_only",
    "unit_only",
    "numeric_plus_unit",
    "step_aware",
]


def run_all_verifiers(problems: pd.DataFrame, candidates: pd.DataFrame) -> pd.DataFrame:
    """Run every deterministic verifier on every candidate.

    Each candidate is left-joined to its problem's ground truth on
    ``problem_id`` and scored by ``answer_only``, ``unit_only``,
    ``numeric_plus_unit`` and ``step_aware``.

    Args:
        problems: Must provide ``problem_id``, ``ground_truth_value`` and
            ``ground_truth_unit``; ``problem_id`` must be unique.
        candidates: One row per candidate. Must provide ``problem_id``,
            ``candidate_id``, ``category``, ``candidate_type``, ``is_valid``
            plus whatever columns the individual verifiers read.

    Returns:
        A long-format frame with one row per (candidate, method) and the
        columns ``problem_id``, ``candidate_id``, ``category``,
        ``candidate_type``, ``is_valid``, ``method``, ``split`` (always
        ``"all"``), ``prediction`` and ``supported``. When a verifier returns
        NaN (e.g. ``step_aware`` for a candidate without structured steps),
        ``prediction`` is NaN and ``supported`` is False. An empty
        ``candidates`` frame yields an empty frame that still carries these
        columns.

    Raises:
        pandas.errors.MergeError: If ``problem_id`` is duplicated in
            ``problems`` (enforced by ``validate="many_to_one"``).
    """

    output_columns = [
        "problem_id",
        "candidate_id",
        "category",
        "candidate_type",
        "is_valid",
        "method",
        "split",
        "prediction",
        "supported",
    ]

    merged = candidates.merge(
        problems[["problem_id", "ground_truth_value", "ground_truth_unit"]],
        on="problem_id",
        how="left",
        validate="many_to_one",
    )

    rows: list[dict[str, object]] = []
    for record in merged.to_dict(orient="records"):
        predictions = {
            "answer_only": answer_only(record),
            "unit_only": unit_only(record),
            "numeric_plus_unit": numeric_plus_unit(record),
            "step_aware": step_aware(record),
        }
        for method, prediction in predictions.items():
            supported = not _is_nan_prediction(prediction)
            rows.append(
                {
                    "problem_id": record["problem_id"],
                    "candidate_id": record["candidate_id"],
                    "category": record["category"],
                    "candidate_type": record["candidate_type"],
                    "is_valid": bool(record["is_valid"]),
                    "method": method,
                    "split": "all",
                    # Normalise every NaN flavour to np.nan for unsupported rows.
                    "prediction": prediction if supported else np.nan,
                    "supported": supported,
                }
            )
    # Explicit columns: without them an empty ``rows`` list produces a frame
    # with no columns at all, which breaks callers that select by name.
    return pd.DataFrame(rows, columns=output_columns)


def answer_only(row: dict[str, Any]) -> bool:
    """Judge a candidate purely on its number, with units disregarded.

    Compares ``row["candidate_value"]`` against ``row["ground_truth_value"]``
    via ``numeric_close``; the unit columns are never consulted.
    """

    candidate_number = row["candidate_value"]
    reference_number = row["ground_truth_value"]
    return numeric_close(candidate_number, reference_number)


def unit_only(row: dict[str, Any]) -> bool:
    """Judge a candidate purely on its unit, with the number disregarded.

    Delegates to ``unit_compatible`` on ``row["candidate_unit"]`` and
    ``row["ground_truth_unit"]``.
    """

    candidate_unit, reference_unit = row["candidate_unit"], row["ground_truth_unit"]
    return unit_compatible(candidate_unit, reference_unit)


def numeric_plus_unit(row: dict[str, Any]) -> bool:
    """Judge a candidate on both its unit and its unit-converted number.

    The candidate passes only when its unit is compatible with the
    reference unit, its value converts into the reference unit (i.e.
    ``convert_value`` does not return None), and the converted value is
    ``numeric_close`` to ``ground_truth_value``. Any other outcome is False.
    """

    source_unit = row["candidate_unit"]
    target_unit = row["ground_truth_unit"]
    if unit_compatible(source_unit, target_unit):
        in_target_unit = convert_value(row["candidate_value"], source_unit, target_unit)
        if in_target_unit is not None:
            return numeric_close(in_target_unit, row["ground_truth_value"])
    return False


def step_aware(row: dict[str, Any]) -> bool | float:
    """Check structured arithmetic steps and the final unit-aware answer.

    Args:
        row: One merged candidate record. ``structured_steps_json`` is read
            with ``row.get`` (so the key may be absent); when every step
            passes, the keys used by ``numeric_plus_unit`` are read too.

    Returns:
        ``np.nan`` (method unsupported for this candidate) when the steps
        value is missing (None, float NaN or ``pd.NA``), blank, decodes to
        something other than a list, or decodes to an empty list.
        ``False`` when the value is not valid JSON or any step fails its
        check. Otherwise the result of ``numeric_plus_unit(row)``.
    """

    raw_steps = row.get("structured_steps_json")
    # pd.NA (e.g. from a pandas StringDtype column) is not a float; without
    # the explicit check it would be stringified to "<NA>", fail JSON
    # decoding and be *rejected* instead of reported as unsupported.
    if (
        raw_steps is None
        or raw_steps is pd.NA
        or (isinstance(raw_steps, float) and math.isnan(raw_steps))
    ):
        return np.nan
    text = str(raw_steps)
    if not text.strip():
        return np.nan
    try:
        steps = json.loads(text)
    except json.JSONDecodeError:
        return False
    if not isinstance(steps, list) or not steps:
        return np.nan

    for step in steps:
        if not _check_step(step):
            return False
    return numeric_plus_unit(row)


def _check_step(step: dict[str, Any]) -> bool:
    op = step.get("op")
    if op in {"add", "subtract", "multiply", "divide"}:
        return _check_binary_step(step, op)
    if op == "power":
        return _check_power_step(step)
    if op == "convert":
        return _check_convert_step(step)
    return False


def _check_binary_step(step: dict[str, Any], op: str) -> bool:
    """Recompute an add/subtract/multiply/divide step and compare its result.

    Args:
        step: Reads ``left_value``/``left_unit``, ``right_value``/``right_unit``,
            ``result_unit`` and ``result_value``.
        op: One of ``"add"``, ``"subtract"``, ``"multiply"``; any other value
            is treated as division.

    Returns:
        True when the recomputed value, expressed in ``result_unit``, is
        ``numeric_close`` to ``result_value``. False when an operand or the
        result unit cannot be built, when add/subtract operands have
        incompatible units, when the computation or conversion fails, or
        when ``result_value`` is missing or non-numeric. The function never
        raises for malformed step data.
    """
    left = make_quantity(step.get("left_value"), step.get("left_unit"))
    right = make_quantity(step.get("right_value"), step.get("right_unit"))
    result_unit = parse_unit(step.get("result_unit"))
    if left is None or right is None or result_unit is None:
        return False

    try:
        if op == "add":
            if not unit_compatible(step.get("left_unit"), step.get("right_unit")):
                return False
            expected = left + right.to(left.units)
        elif op == "subtract":
            if not unit_compatible(step.get("left_unit"), step.get("right_unit")):
                return False
            expected = left - right.to(left.units)
        elif op == "multiply":
            expected = left * right
        else:
            expected = left / right
        converted = expected.to(result_unit)
    except Exception:
        # Best-effort verification: any unit/arithmetic failure rejects the step.
        return False

    try:
        claimed = float(step.get("result_value"))
    except (TypeError, ValueError):
        # Missing (None) or non-numeric claimed result: reject, don't crash.
        return False

    return numeric_close(float(converted.magnitude), claimed)


def _check_power_step(step: dict[str, Any]) -> bool:
    """Verify a ``power`` step: ``result == base ** exponent``.

    Base, exponent and result are built from ``base_value``/``base_unit``,
    ``exponent_value``/``exponent_unit`` and ``result_value``/``result_unit``,
    and all three must be dimensionless. Returns False (never raises) when a
    quantity cannot be built, any of them has dimensions, or the power is
    undefined or overflows. Otherwise the result is whether the claimed
    result is ``numeric_close`` to the recomputed power.
    """
    base = make_quantity(step.get("base_value"), step.get("base_unit"))
    exponent = make_quantity(step.get("exponent_value"), step.get("exponent_unit"))
    result = make_quantity(step.get("result_value"), step.get("result_unit"))
    if base is None or exponent is None or result is None:
        return False
    if not _is_dimensionless(base) or not _is_dimensionless(exponent) or not _is_dimensionless(result):
        return False
    try:
        base_magnitude = float(base.to(UREG.dimensionless).magnitude)
        exponent_magnitude = float(exponent.to(UREG.dimensionless).magnitude)
        # math.pow instead of `**`: for a negative base with a non-integer
        # exponent `**` silently returns a complex number, whereas math.pow
        # raises ValueError (and OverflowError on overflow), both handled below.
        expected = math.pow(base_magnitude, exponent_magnitude)
        observed = float(result.to(UREG.dimensionless).magnitude)
    except Exception:
        return False
    return numeric_close(observed, expected)


def _check_convert_step(step: dict[str, Any]) -> bool:
    """Verify a unit ``convert`` step.

    Converts ``input_value`` from ``input_unit`` to ``result_unit`` with
    ``convert_value`` and compares the outcome with the claimed
    ``result_value``. Returns False when the conversion fails (``None``) or
    when ``result_value`` is missing or non-numeric, rather than raising.
    """
    converted = convert_value(step.get("input_value"), step.get("input_unit"), step.get("result_unit"))
    if converted is None:
        return False
    try:
        claimed = float(step.get("result_value"))
    except (TypeError, ValueError):
        # Missing (None) or non-numeric claimed result: reject, don't crash.
        return False
    return numeric_close(converted, claimed)


def _is_dimensionless(quantity: Any) -> bool:
    """Tell whether *quantity* has the same dimensionality as ``UREG.dimensionless``."""
    own_dims = quantity.dimensionality
    return own_dims == UREG.dimensionless.dimensionality


def _is_nan_prediction(value: object) -> bool:
    return isinstance(value, float) and math.isnan(value)
