"""Shared utilities for BioDimBench generation, verification, and reporting."""

from __future__ import annotations

import json
import logging
import math
from functools import lru_cache
from pathlib import Path
from typing import Any

import numpy as np
import pint


# Category labels, in their canonical order.
CATEGORY_NAMES = [
    "Dosage", "Dilution", "Cell count",
    "Half-life decay", "Exponential growth", "Flow rate",
    "Imaging scale area", "Bioink weight/volume percent",
    "Molarity", "Unit conversion",
]

# Every candidate type other than "correct".
INVALID_CANDIDATE_TYPES = [
    "wrong_arithmetic",
    "wrong_formula",
    "wrong_unit",
    "missing_conversion",
    "plausible_scalar_wrong_unit",
]

# All candidate types: "correct" first, then the invalid ones in the same order.
CANDIDATE_TYPES = ["correct", *INVALID_CANDIDATE_TYPES]

# Method key -> display label; the insertion order is also the method order.
METHOD_DISPLAY = {
    "answer_only": "Answer only",
    "unit_only": "Unit only",
    "numeric_plus_unit": "Numeric + unit",
    "step_aware": "Step-aware",
    "learned_baseline": "Learned baseline",
}

# Method keys as a list, in the order they appear in METHOD_DISPLAY.
METHOD_ORDER = list(METHOD_DISPLAY)

# Invalid candidate type -> short display label.
ERROR_TYPE_DISPLAY = {
    "wrong_arithmetic": "Arithmetic",
    "wrong_formula": "Formula",
    "wrong_unit": "Unit",
    "missing_conversion": "Conversion",
    "plausible_scalar_wrong_unit": "Plausible scalar",
}


def build_unit_registry() -> pint.UnitRegistry:
    """Return a new pint registry extended with cell and pixel units.

    The registry is created with ``autoconvert_offset_to_baseunit=True`` and
    then given definitions for ``cell``/``cells``, ``kilocell`` (with
    ``kcell``/``kcells``), and ``pixel``/``pixels`` (with ``px``). A
    definition that pint rejects as malformed is logged and skipped; any
    other ``ValueError`` raised while defining a unit is ignored.
    """

    ureg = pint.UnitRegistry(autoconvert_offset_to_baseunit=True)
    extra_definitions = (
        "cell = [cell]",
        "cells = cell",
        "kilocell = 1000 * cell = kcell = kcells",
        "pixel = 1 = px",
        "pixels = pixel",
    )
    for spec in extra_definitions:
        try:
            ureg.define(spec)
        except pint.errors.DefinitionSyntaxError:
            logging.warning("Could not define pint unit: %s", spec)
        except ValueError:
            # Treated as "unit already present", which is harmless when
            # modules reload, so move on to the next definition.
            continue
    return ureg


# Module-wide unit registry, built once at import time. parse_unit resolves
# every unit expression against this instance.
UREG = build_unit_registry()

# Accepted spellings grouped under the unit expression they normalize to.
# The order of groups, and of spellings within a group, fixes the insertion
# order of UNIT_ALIASES.
_UNIT_ALIAS_GROUPS = (
    ("", ("", "1", "dimensionless")),
    ("kilogram", ("kg",)),
    ("gram", ("g",)),
    ("milligram", ("mg",)),
    ("microgram", ("ug", "µg")),
    ("liter", ("L", "l")),
    ("milliliter", ("mL", "ml")),
    ("hour", ("hr", "hrs", "hour", "hours")),
    ("minute", ("min", "mins", "minute", "minutes")),
    ("milligram / kilogram", ("mg/kg",)),
    ("milligram / liter", ("mg/L",)),
    ("gram / liter", ("g/L",)),
    ("millimole / liter", ("mM",)),
    ("mole / liter", ("M",)),
    ("millimole", ("mmol",)),
    ("mole", ("mol",)),
    ("millimeter", ("mm",)),
    ("centimeter", ("cm",)),
    ("millimeter ** 2", ("mm^2", "mm2")),
    ("centimeter ** 2", ("cm^2", "cm2")),
    ("cell", ("cells", "cell")),
    ("kilocell", ("kcell", "kcells")),
    ("cell / milliliter", ("cells/mL", "cell/mL")),
    ("gram / milliliter", ("g/mL",)),
    ("milliliter / hour", ("mL/hr", "mL/hour")),
    ("millimeter / pixel", ("mm/pixel", "mm/px")),
    ("pixel", ("pixel", "pixels", "px")),
)

# Flat lookup: accepted spelling -> normalized unit expression.
UNIT_ALIASES = {
    spelling: canonical
    for canonical, spellings in _UNIT_ALIAS_GROUPS
    for spelling in spellings
}


def ensure_output_dirs(root: Path) -> dict[str, Path]:
    """Make sure the output sub-directories exist under ``root``.

    Creates ``outputs/data``, ``outputs/metrics``, ``outputs/figures`` and
    ``outputs/latex`` below ``root`` (including missing parents; existing
    directories are left alone) and returns them keyed by their short name.
    """

    outputs_dir = root / "outputs"
    paths = {
        name: outputs_dir / name
        for name in ("data", "metrics", "figures", "latex")
    }
    for directory in paths.values():
        directory.mkdir(parents=True, exist_ok=True)
    return paths


def to_json(data: dict[str, Any] | list[dict[str, Any]]) -> str:
    """Serialize structured fields deterministically."""

    return json.dumps(data, sort_keys=True, separators=(",", ":"))


def from_json(value: str | float | None, default: Any) -> Any:
    """Parse a JSON field, returning a default for blanks or invalid values."""

    if value is None:
        return default
    if isinstance(value, float) and math.isnan(value):
        return default
    if not str(value).strip():
        return default
    try:
        return json.loads(value)
    except json.JSONDecodeError:
        logging.warning("Could not parse JSON field: %s", value)
        return default


def format_number(value: float, sig: int = 6) -> str:
    """Render ``value`` as short text for problem and solution strings.

    Non-finite values are returned via ``str``; zero is ``"0"``. Magnitudes
    in ``[1e-3, 1e6)`` use the ``g`` format with ``sig`` digits; everything
    else uses the ``e`` format with ``sig`` digits after the decimal point.
    """

    if not np.isfinite(value):
        return str(value)
    if value == 0:
        return "0"
    in_plain_range = 1e-3 <= abs(value) < 1e6
    spec = f".{sig}g" if in_plain_range else f".{sig}e"
    return format(value, spec)


def normalize_unit(unit: str | float | None) -> str:
    """Translate a benchmark unit string into a pint-compatible expression.

    ``None`` and NaN floats become the empty string. Anything else is
    converted to text, stripped of surrounding whitespace, and looked up in
    ``UNIT_ALIASES``; text with no alias is returned as-is (stripped).
    """

    is_missing = unit is None or (isinstance(unit, float) and math.isnan(unit))
    if is_missing:
        return ""
    stripped = str(unit).strip()
    return UNIT_ALIASES.get(stripped, stripped)


@lru_cache(maxsize=256)
def parse_unit(unit: str | float | None) -> pint.Unit | None:
    """Resolve ``unit`` to a pint unit from ``UREG``.

    The input is normalized first; an empty normalized string yields the
    dimensionless unit. If pint cannot parse the expression, a warning is
    logged and ``None`` is returned. Results (including ``None``) are cached
    per input value.
    """

    expression = normalize_unit(unit)
    if not expression:
        return UREG.dimensionless
    try:
        parsed = UREG.parse_units(expression)
    except Exception as error:  # pint exposes several parse-time exceptions.
        logging.warning("Could not parse unit '%s' as '%s': %s", unit, expression, error)
        return None
    return parsed


def make_quantity(value: float, unit: str | float | None) -> pint.Quantity | None:
    """Combine ``value`` and ``unit`` into a pint quantity.

    Returns ``None`` when the unit cannot be parsed, or when building the
    quantity raises (the failure is logged as a warning).
    """

    resolved_unit = parse_unit(unit)
    if resolved_unit is None:
        return None
    try:
        quantity = float(value) * resolved_unit
    except Exception as error:
        logging.warning("Could not create quantity for %s %s: %s", value, unit, error)
        return None
    return quantity


def unit_compatible(unit_a: str | float | None, unit_b: str | float | None) -> bool:
    """Tell whether both units parse and have equal dimensionality.

    Returns ``False`` if either unit fails to parse.
    """

    left = parse_unit(unit_a)
    right = parse_unit(unit_b)
    both_parsed = left is not None and right is not None
    return both_parsed and left.dimensionality == right.dimensionality


def convert_value(value: float, from_unit: str, to_unit: str) -> float | None:
    """Express ``value`` given in ``from_unit`` as a magnitude in ``to_unit``.

    Returns ``None`` when either unit cannot be parsed or when the
    conversion itself fails; a failed conversion is logged at debug level.
    """

    source = make_quantity(value, from_unit)
    destination = parse_unit(to_unit)
    if source is None or destination is None:
        return None
    try:
        magnitude = source.to(destination).magnitude
        result = float(magnitude)
    except Exception as error:
        logging.debug("Could not convert %s %s to %s: %s", value, from_unit, to_unit, error)
        return None
    return result


def numeric_close(candidate: float, truth: float, rtol: float = 1e-5, atol: float = 1e-12) -> bool:
    """Check whether ``candidate`` matches ``truth`` within tolerance.

    Both inputs are coerced with ``float``; values that cannot be coerced,
    or that are not finite, never match. The allowed absolute difference is
    ``atol + rtol * max(abs(truth), 1e-12)``.
    """

    try:
        observed, expected = float(candidate), float(truth)
    except (TypeError, ValueError):
        return False
    if not (np.isfinite(observed) and np.isfinite(expected)):
        return False
    # The 1e-12 floor keeps the relative term non-zero when truth is 0.
    tolerance = atol + rtol * max(abs(expected), 1e-12)
    return bool(abs(observed - expected) <= tolerance)


def safe_divide(numerator: float, denominator: float, fallback: float = np.nan) -> float:
    """Return ``numerator / denominator``, or ``fallback`` when the denominator is zero.

    ``fallback`` defaults to NaN.
    """

    return fallback if denominator == 0 else numerator / denominator


def method_label(method: str, split: str = "all") -> str:
    """Build the label shown for ``method`` in plots and reports.

    Uses the ``METHOD_DISPLAY`` entry when there is one, otherwise the raw
    ``method`` string. Any ``split`` other than ``"all"`` is appended in
    parentheses.
    """

    base = METHOD_DISPLAY.get(method, method)
    return base if split == "all" else f"{base} ({split})"


def percent(value: float) -> str:
    """Render a fraction as a percentage with one decimal place.

    ``None`` and non-finite values are rendered as ``"NA"``.
    """

    is_missing = value is None or not np.isfinite(value)
    return "NA" if is_missing else f"{100 * value:.1f}%"
