"""Synthetic biomedical math benchmark generation."""

from __future__ import annotations

import numpy as np
import pandas as pd

from .utils import CATEGORY_NAMES, format_number, to_json


def generate_benchmark(n: int, seed: int) -> pd.DataFrame:
    """Build a table of ``n`` synthetic problems spread evenly over the categories.

    Every category in ``CATEGORY_NAMES`` receives ``n // k`` problems (``k``
    being the number of categories); the first ``n % k`` categories receive one
    extra so the totals add up to ``n``.  Problems are produced category by
    category, in ``CATEGORY_NAMES`` order, from a single generator seeded with
    ``seed``, and are numbered consecutively starting at 1.

    Args:
        n: Total number of problems to generate.
        seed: Seed handed to ``np.random.default_rng``.

    Returns:
        A DataFrame with one row per dict returned by ``_generate_problem``.
    """

    rng = np.random.default_rng(seed)

    category_total = len(CATEGORY_NAMES)
    base_share, leftover = divmod(n, category_total)
    per_category = np.full(category_total, base_share, dtype=int)
    per_category[:leftover] += 1  # hand the remainder to the leading categories

    # Flatten the quotas into the exact sequence of categories to generate, so
    # the running problem number is simply the 1-based position in that list.
    schedule: list[str] = []
    for label, quota in zip(CATEGORY_NAMES, per_category):
        schedule.extend([label] * int(quota))

    records = [
        _generate_problem(label, rng, position)
        for position, label in enumerate(schedule, start=1)
    ]
    return pd.DataFrame(records)


def _dosage_problem(problem_id: str, category: str, rng: np.random.Generator) -> dict[str, object]:
    """Weight-based dosing: total mg = body weight (kg) x dose (mg/kg)."""
    fmt = format_number
    body_weight = float(np.round(rng.uniform(40, 120), 1))
    per_kg_dose = float(np.round(rng.uniform(0.5, 15), 2))
    answer = body_weight * per_kg_dose
    question = (
        f"A patient weighs {fmt(body_weight)} kg and receives "
        f"{fmt(per_kg_dose)} mg/kg. What total dose is required?"
    )
    worked = (
        f"Total dose = {fmt(body_weight)} kg x "
        f"{fmt(per_kg_dose)} mg/kg = {fmt(answer)} mg."
    )
    inputs = {"weight_kg": body_weight, "dose_mg_per_kg": per_kg_dose}
    return _row(problem_id, category, question, answer, "mg", "total_mg=W*D", inputs, worked)


def _dilution_problem(problem_id: str, category: str, rng: np.random.Generator) -> dict[str, object]:
    """Stock dilution: V1 = C2 * V2 / C1, with the target capped at 40% of the stock (max 20 mM)."""
    fmt = format_number
    stock_conc = float(rng.choice([25, 50, 100, 200, 500]))
    target_conc = float(np.round(rng.uniform(0.5, min(20, stock_conc * 0.4)), 2))
    total_volume = float(np.round(rng.uniform(2, 250), 1))
    answer = target_conc * total_volume / stock_conc
    question = (
        f"Prepare {fmt(total_volume)} mL of {fmt(target_conc)} mM solution "
        f"from a {fmt(stock_conc)} mM stock. What volume of stock is needed?"
    )
    worked = (
        f"V1 = C2 V2 / C1 = {fmt(target_conc)} x {fmt(total_volume)} / "
        f"{fmt(stock_conc)} = {fmt(answer)} mL."
    )
    inputs = {
        "stock_concentration_mM": stock_conc,
        "target_concentration_mM": target_conc,
        "final_volume_mL": total_volume,
    }
    return _row(problem_id, category, question, answer, "mL", "V1=C2*V2/C1", inputs, worked)


def _cell_count_problem(problem_id: str, category: str, rng: np.random.Generator) -> dict[str, object]:
    """Cell count: cells = density (cells/mL) x volume (mL)."""
    fmt = format_number
    cells_per_ml = float(rng.choice([1e5, 2e5, 5e5, 1e6, 2e6, 5e6]))
    sample_volume = float(np.round(rng.uniform(0.1, 20), 2))
    answer = cells_per_ml * sample_volume
    question = (
        f"A suspension has {fmt(cells_per_ml)} cells/mL and volume "
        f"{fmt(sample_volume)} mL. How many cells are present?"
    )
    worked = (
        f"Cells = {fmt(cells_per_ml)} cells/mL x {fmt(sample_volume)} mL = "
        f"{fmt(answer)} cells."
    )
    inputs = {"density_cells_per_mL": cells_per_ml, "volume_mL": sample_volume}
    return _row(problem_id, category, question, answer, "cells", "cells=rho*V", inputs, worked)


def _half_life_problem(problem_id: str, category: str, rng: np.random.Generator) -> dict[str, object]:
    """Half-life decay: C = C0 * 0.5 ** (t / T)."""
    fmt = format_number
    start_conc = float(np.round(rng.uniform(5, 100), 2))
    halving_hours = float(np.round(rng.uniform(1, 24), 2))
    hours_passed = float(np.round(rng.uniform(0.5, 72), 2))
    answer = start_conc * (0.5 ** (hours_passed / halving_hours))
    question = (
        f"Initial concentration is {fmt(start_conc)} mg/L, half-life is "
        f"{fmt(halving_hours)} hours, and elapsed time is {fmt(hours_passed)} hours. "
        "What concentration remains?"
    )
    worked = (
        f"C = C0 x 0.5^(t/T) = {fmt(start_conc)} x 0.5^"
        f"({fmt(hours_passed)}/{fmt(halving_hours)}) = {fmt(answer)} mg/L."
    )
    inputs = {
        "initial_concentration_mg_per_L": start_conc,
        "half_life_hr": halving_hours,
        "elapsed_hr": hours_passed,
    }
    return _row(problem_id, category, question, answer, "mg/L", "C=C0*0.5^(t/T)", inputs, worked)


def _growth_problem(problem_id: str, category: str, rng: np.random.Generator) -> dict[str, object]:
    """Exponential growth: N = N0 * 2 ** (t / Td)."""
    fmt = format_number
    seed_cells = float(rng.choice([1e4, 2e4, 5e4, 1e5, 2e5, 5e5, 1e6]))
    doubling_hours = float(np.round(rng.uniform(8, 48), 2))
    hours_passed = float(np.round(rng.uniform(4, 96), 2))
    answer = seed_cells * (2 ** (hours_passed / doubling_hours))
    question = (
        f"Initial culture size is {fmt(seed_cells)} cells, doubling time is "
        f"{fmt(doubling_hours)} hours, and elapsed time is {fmt(hours_passed)} hours. "
        "What final number of cells is expected?"
    )
    worked = (
        f"N = N0 x 2^(t/Td) = {fmt(seed_cells)} x 2^"
        f"({fmt(hours_passed)}/{fmt(doubling_hours)}) = {fmt(answer)} cells."
    )
    inputs = {
        "initial_cells": seed_cells,
        "doubling_time_hr": doubling_hours,
        "elapsed_hr": hours_passed,
    }
    return _row(problem_id, category, question, answer, "cells", "N=N0*2^(t/Td)", inputs, worked)


def _flow_rate_problem(problem_id: str, category: str, rng: np.random.Generator) -> dict[str, object]:
    """Pump delivery: volume (mL) = flow rate (mL/hr) x time (hr)."""
    fmt = format_number
    rate = float(np.round(rng.uniform(0.1, 50), 2))
    hours_run = float(np.round(rng.uniform(0.25, 24), 2))
    answer = rate * hours_run
    question = (
        f"A pump runs at {fmt(rate)} mL/hr for {fmt(hours_run)} hours. "
        "What total volume is delivered?"
    )
    worked = (
        f"Volume = {fmt(rate)} mL/hr x {fmt(hours_run)} hr = "
        f"{fmt(answer)} mL."
    )
    inputs = {"flow_mL_per_hr": rate, "elapsed_hr": hours_run}
    return _row(problem_id, category, question, answer, "mL", "V=Q*t", inputs, worked)


def _imaging_area_problem(problem_id: str, category: str, rng: np.random.Generator) -> dict[str, object]:
    """Imaging scale: area (mm^2) = width_px x height_px x (mm per pixel) ** 2."""
    fmt = format_number
    # Pixel extents are drawn as integers in [50, 2000] and are interpolated
    # into the text directly, without going through the number formatter.
    cols = int(rng.integers(50, 2001))
    rows = int(rng.integers(50, 2001))
    mm_per_px = float(np.round(rng.uniform(0.001, 0.1), 4))
    answer = cols * rows * (mm_per_px**2)
    question = (
        f"An object spans width {cols} pixels and height {rows} pixels. "
        f"Pixel size is {fmt(mm_per_px)} mm/pixel. What is the area?"
    )
    worked = (
        f"Area = {cols} x {rows} x {fmt(mm_per_px)}^2 = "
        f"{fmt(answer)} mm^2."
    )
    inputs = {"width_px": cols, "height_px": rows, "pixel_size_mm_per_pixel": mm_per_px}
    return _row(problem_id, category, question, answer, "mm^2", "area=w*h*s^2", inputs, worked)


def _bioink_problem(problem_id: str, category: str, rng: np.random.Generator) -> dict[str, object]:
    """Weight/volume percent: grams = (percent / 100) x volume (mL)."""
    fmt = format_number
    batch_volume = float(np.round(rng.uniform(1, 100), 1))
    wv_percent = float(np.round(rng.uniform(0.5, 20), 2))
    answer = (wv_percent / 100) * batch_volume
    question = (
        f"Prepare {fmt(batch_volume)} mL of {fmt(wv_percent)} percent w/v hydrogel. "
        "How many grams solute are needed? Assume percent w/v means grams per 100 mL."
    )
    worked = (
        f"Mass = ({fmt(wv_percent)}/100) x {fmt(batch_volume)} = "
        f"{fmt(answer)} g."
    )
    inputs = {"volume_mL": batch_volume, "percent_wv": wv_percent}
    return _row(problem_id, category, question, answer, "g", "grams=(p/100)*V", inputs, worked)


def _molarity_problem(problem_id: str, category: str, rng: np.random.Generator) -> dict[str, object]:
    """Amount of substance: mmol = concentration (mM) x volume (mL) / 1000."""
    fmt = format_number
    sample_volume = float(np.round(rng.uniform(0.1, 100), 2))
    millimolar = float(np.round(rng.uniform(0.1, 1000), 2))
    answer = millimolar * sample_volume / 1000
    question = (
        f"How many millimoles are in {fmt(sample_volume)} mL of a "
        f"{fmt(millimolar)} mM solution?"
    )
    worked = (
        f"mmol = C x V / 1000 = {fmt(millimolar)} x "
        f"{fmt(sample_volume)} / 1000 = {fmt(answer)} mmol."
    )
    inputs = {"volume_mL": sample_volume, "concentration_mM": millimolar}
    return _row(problem_id, category, question, answer, "mmol", "mmol=C*V/1000", inputs, worked)


# Category name -> builder taking (problem_id, category, rng).  Categories not
# listed here are handled by the unit-conversion generator.
_PROBLEM_BUILDERS = {
    "Dosage": _dosage_problem,
    "Dilution": _dilution_problem,
    "Cell count": _cell_count_problem,
    "Half-life decay": _half_life_problem,
    "Exponential growth": _growth_problem,
    "Flow rate": _flow_rate_problem,
    "Imaging scale area": _imaging_area_problem,
    "Bioink weight/volume percent": _bioink_problem,
    "Molarity": _molarity_problem,
}


def _generate_problem(category: str, rng: np.random.Generator, problem_index: int) -> dict[str, object]:
    """Generate one problem record for ``category``.

    The problem id is ``"P"`` followed by ``problem_index`` zero-padded to five
    digits.  Parameters are sampled from ``rng``; the returned dict is whatever
    ``_row`` produces for the sampled problem.  Any category without a
    dedicated builder yields a unit-conversion problem.
    """
    problem_id = f"P{problem_index:05d}"
    builder = _PROBLEM_BUILDERS.get(category)
    if builder is None:
        return _generate_unit_conversion(problem_id, rng)
    return builder(problem_id, category, rng)


def _generate_unit_conversion(problem_id: str, rng: np.random.Generator) -> dict[str, object]:
    """Generate one unit-conversion problem record.

    A conversion case is drawn first, then a source value from that case's
    fixed list of candidates.  The three metric cases divide by 1000; the
    hours-to-minutes case multiplies by 60.  The case name doubles as the
    formula name passed to ``_row``, and the category is always
    ``"Unit conversion"``.
    """
    # case name -> (candidate source values, source unit, target unit)
    catalogue = {
        "mg_to_g": ([50, 100, 250, 500, 750, 1000, 1250, 2000, 5000], "mg", "g"),
        "mL_to_L": ([10, 25, 50, 100, 250, 500, 750, 1000, 2000], "mL", "L"),
        "ug_to_mg": ([25, 50, 100, 250, 500, 1000, 2500, 5000], "ug", "mg"),
        "hours_to_minutes": ([0.25, 0.5, 1, 1.5, 2, 3, 4, 6, 12, 24], "hr", "min"),
    }

    case = str(rng.choice(list(catalogue)))
    candidates, source_unit, target_unit = catalogue[case]
    amount = float(rng.choice(candidates))
    converted = amount * 60 if case == "hours_to_minutes" else amount / 1000

    question = f"Convert {format_number(amount)} {source_unit} to {target_unit}."
    worked = f"{format_number(amount)} {source_unit} = {format_number(converted)} {target_unit}."
    inputs = {
        "conversion_case": case,
        "source_value": amount,
        "source_unit": source_unit,
        "target_unit": target_unit,
    }
    return _row(problem_id, "Unit conversion", question, converted, target_unit, case, inputs, worked)


def _row(
    problem_id: str,
    category: str,
    problem_text: str,
    ground_truth_value: float,
    ground_truth_unit: str,
    formula_name: str,
    parameters: dict[str, object],
    correct_solution_text: str,
) -> dict[str, object]:
    """Assemble one benchmark record as a flat dict.

    The ground-truth value is coerced with ``float`` and the parameter mapping
    is serialised with ``to_json`` into the ``parameters_json`` field; every
    other argument is stored unchanged under the key of the same name.
    """
    record: dict[str, object] = {}
    record["problem_id"] = problem_id
    record["category"] = category
    record["problem_text"] = problem_text
    record["ground_truth_value"] = float(ground_truth_value)
    record["ground_truth_unit"] = ground_truth_unit
    record["formula_name"] = formula_name
    record["parameters_json"] = to_json(parameters)
    record["correct_solution_text"] = correct_solution_text
    return record
