from __future__ import annotations

import argparse
import json
import random
import sys
from pathlib import Path
from typing import Any, Callable, Dict, List

try:
    from .dri_data import build_datasets as build_dri_datasets
except ImportError:
    from dri_data import build_datasets as build_dri_datasets

# Default demographic profile used as the starting point for hand-built samples
# (_base_input copies it and applies overrides on top).
BASE_SAMPLE: Dict[str, Any] = {
    "sex": "Female",
    "age": 29,
    "age_unit": "years",
    "weight": 62,
    "weight_unit": "kg",
    "height": 165,
    "height_unit": "cm",
    "activity_level": "Low active",
    "is_pregnant": False,
    "lactation_status": "none",
}

# Cuisine names sampled into each row's "food_type" prompt field.
CUISINES = [
    "indian",
    "chinese",
    "mexican",
    "thai",
    "japanese",
    "mediterranean",
    "italian",
    "korean",
    "american",
    "middle eastern",
]

# Seed for the random.Random created in build_datasets, so generated suites are reproducible.
FIXED_SEED = 42
# Candidate values for the "meal_plan_days" prompt field.
MEAL_PLAN_DAYS_OPTIONS = (3, 5, 7)
# Default value of the --mfp-target CLI option (number of MFP macro rows).
PRACTICAL_MFP_TARGET = 320
# Sample keys forwarded as keyword arguments to calculate_mfp_macros (see _to_mfp_input);
# keys that are absent or None are dropped.
MFP_INPUT_KEYS = (
    "sex",
    "age",
    "weight",
    "height",
    "activity_level",
    "goal",
    "weekly_rate_lbs",
    "carb_fat_preference",
    "weight_unit",
    "height_unit",
    "height_inches",
)
# Activity levels sampled for MFP macro rows.
MFP_ACTIVITY_LEVEL_OPTIONS = (
    "sedentary",
    "lightly active",
    "moderately active",
    "very active",
    "extra active",
)
# Goals used both as sample values and as the goal half of the (sex, goal) strata.
MFP_GOAL_OPTIONS = ("lose weight", "maintain", "gain weight")
# Macro-split preferences sampled into "carb_fat_preference".
MFP_CARB_FAT_PREFERENCE_OPTIONS = (
    "balanced",
    "lower carb",
    "higher carb",
    "higher protein",
)
# Ages (years) always covered by seed rows and favored by the random generator.
# NOTE(review): presumably age-bracket edges of the MFP calculator — confirm against
# calculate_mfp_macros.
MFP_AGE_BOUNDARY_POINTS = (15.0, 18.0, 25.0, 40.0, 65.0)
# Output file stem per dataset name; datasets not listed use their own name as the stem.
SUITE_OUTPUT_FILE_STEMS = {
    "mfp_macro": "mfp",
}


# Shared prompt, normalization, and unit-conversion helpers.
def _to_float(value: Any) -> float:
    """Coerce *value* to ``float``; ``float()`` raises TypeError/ValueError if it cannot."""
    as_float = float(value)
    return as_float


def _age_years_and_months(age: Any, age_unit: str) -> tuple[int, int]:
    """Split an age into whole years and leftover months.

    The age is first expressed in months (``age_unit`` "months", case/whitespace
    insensitive, means it already is; any other unit is treated as years), rounded to
    the nearest whole month and clamped at zero.
    """
    raw_age = float(age)
    given_in_months = age_unit.strip().lower() == "months"
    total_months = int(round(raw_age if given_in_months else raw_age * 12.0))
    if total_months < 0:
        total_months = 0
    return divmod(total_months, 12)


def _person_label(sex: str, age: Any, age_unit: str) -> str:
    """Return the noun used for the person in the prompt ("woman", "boy", "baby", ...).

    The whole-year age picks the life stage: under 2 is a baby, under 18 a child,
    otherwise an adult. Sex values other than "female"/"male" (after trimming and
    lower-casing) fall back to a neutral noun for that stage.
    """
    sex_key = sex.strip().lower()
    whole_years = _age_years_and_months(age, age_unit)[0]

    if whole_years < 2:
        labels, neutral = {"female": "baby girl", "male": "baby boy"}, "baby"
    elif whole_years < 18:
        labels, neutral = {"female": "girl", "male": "boy"}, "child"
    else:
        labels, neutral = {"female": "woman", "male": "man"}, "person"
    return labels.get(sex_key, neutral)


def _format_height_phrase(sample: Dict[str, Any]) -> str:
    """Render the sample's height for the prompt, e.g. "165 cm" or "5 ft 5.0 in".

    For the "ft_in" unit, ``height`` holds feet and the optional ``height_inches``
    holds inches (omitted from the phrase when None). Any other unit string is
    printed verbatim after the value.
    """
    value = sample.get("height")
    unit = str(sample.get("height_unit", "cm"))
    if unit != "ft_in":
        return f"{value} {unit}"
    inches = sample.get("height_inches")
    return f"{value} ft" if inches is None else f"{value} ft {inches} in"


def _format_weekly_rate_lbs(value: Any) -> str:
    """Format a weekly rate in pounds, dropping the ``.0`` of whole numbers ("1", "0.5")."""
    rate = float(value)
    return str(int(rate)) if rate.is_integer() else str(rate)


def _has_mixed_age(age: Any, age_unit: str) -> bool:
    """Return True when the month-rounded age is not a whole number of years (e.g. 2y 3m)."""
    leftover_months = _age_years_and_months(age, age_unit)[1]
    return leftover_months > 0


def _format_age_phrase(age: Any, age_unit: str) -> str:
    """Render an age for the profile prompt.

    The phrase has two grammatical forms. The form is chosen exactly as
    ``_has_mixed_age`` decides: whether a months remainder is left after rounding
    to whole months.

    * Whole ages are attributive and go before the noun. They are hyphenated with
      a singular unit: "29-year-old", "12-month-old".
    * Ages with a months remainder are predicative and go after "who is":
      "6 months old", "2 years and 3 months old".

    Ages under 24 months are always expressed in months.
    """
    years, remaining_months = _age_years_and_months(age, age_unit)
    total_months = (years * 12) + remaining_months

    if total_months < 24:
        if remaining_months == 0:
            # Only 0 or 12 months land here. Attributive compounds take a singular
            # unit ("12-month-old baby"), never "12-months-old".
            return f"{total_months}-month-old"
        # Predicative use ("who is 6 months old"), so no hyphens.
        month_word = "month" if total_months == 1 else "months"
        return f"{total_months} {month_word} old"

    if remaining_months == 0:
        return f"{years}-year-old"

    # total_months >= 24 here, so years >= 2 and "years" is always plural.
    month_word = "month" if remaining_months == 1 else "months"
    return f"{years} years and {remaining_months} {month_word} old"


def _build_mfp_goal_sentence(sample: Dict[str, Any]) -> str | None:
    """Describe the sample's weight goal, weekly rate, and macro preference in prose.

    Returns None when "goal", "weekly_rate_lbs" and "carb_fat_preference" are all
    unset. How each goal is described:

    * "lose weight"/"gain weight" quote the weekly rate, with a missing rate shown as 0.
    * "maintain" mentions the rate only when it is non-zero.
    * Unrecognized goals are echoed verbatim.
    * With no goal at all, a bare rate is described on its own.
    """
    goal = sample.get("goal")
    rate = sample.get("weekly_rate_lbs")
    preference = sample.get("carb_fat_preference")

    if goal is None and rate is None and preference is None:
        return None

    sentences: List[str] = []

    if goal is None:
        if rate is not None:
            sentences.append(
                f"I'm targeting a weekly weight change of {_format_weekly_rate_lbs(rate)} lb."
            )
    else:
        goal_key = str(goal).strip().lower()
        effective_rate = rate if rate is not None else 0.0
        if goal_key == "lose weight":
            sentences.append(
                f"I want to lose weight at about {_format_weekly_rate_lbs(effective_rate)} lb per week."
            )
        elif goal_key == "gain weight":
            sentences.append(
                f"I want to gain weight at about {_format_weekly_rate_lbs(effective_rate)} lb per week."
            )
        elif goal_key != "maintain":
            sentences.append(f"My goal is {goal}.")
        elif _to_float(effective_rate) == 0.0:
            sentences.append("I want to maintain my weight, with no planned weekly change.")
        else:
            shown = _format_weekly_rate_lbs(effective_rate)
            sentences.append(
                f"I want to maintain my weight, with no planned weekly change (currently set to {shown} lb per week)."
            )

    if preference is not None:
        sentences.append(f"I'd like a {preference} macro split.")

    return " ".join(sentences) if sentences else None


def build_profile_sentence(sample: Dict[str, Any]) -> str:
    """Turn a sample dict into the first-person prompt sent to the model.

    The prompt is built in this order:

    1. Age/sex, body size, activity level, and cuisine.
    2. Optional pregnancy details (gestation weeks, pre-pregnancy weight) and
       lactation status.
    3. The MFP goal sentence, when the sample has goal fields.
    4. A closing meal-plan request. It asks for "macro" targets when an MFP goal
       sentence is present, otherwise "nutrient" targets.
    """
    get = sample.get
    sex = str(get("sex", ""))
    age = get("age")
    age_unit = str(get("age_unit", "years"))
    weight = get("weight")
    weight_unit = str(get("weight_unit", "kg"))
    height_phrase = _format_height_phrase(sample)
    activity_level = str(get("activity_level", "Unknown"))
    food_type = str(get("food_type", "unspecified"))
    meal_plan_days = int(get("meal_plan_days", 3))
    is_pregnant = bool(get("is_pregnant", False))
    lactation_status = str(get("lactation_status", "none"))
    age_phrase = _format_age_phrase(age, age_unit)
    person_label = _person_label(sex, age, age_unit)

    # Mixed ages ("2 years and 3 months old") read naturally only after the noun.
    intro = (
        f"I'm a {person_label} who is {age_phrase}."
        if _has_mixed_age(age, age_unit)
        else f"I'm a {age_phrase} {person_label}."
    )
    parts = [
        intro,
        f"I weigh {weight} {weight_unit} and I'm {height_phrase} tall.",
        f"My activity level is {activity_level}.",
        f"I want to eat {food_type} food.",
    ]

    if is_pregnant:
        parts.append("I'm pregnant.")
        weeks = get("gestation_weeks")
        if weeks is not None:
            parts.append(f"I'm at {weeks} weeks of gestation.")
        prior_weight = get("prepregnancy_weight")
        if prior_weight is not None:
            prior_unit = str(get("prepregnancy_weight_unit", "kg"))
            parts.append(f"My pre-pregnancy weight was {prior_weight} {prior_unit}.")

    if lactation_status != "none":
        parts.append(f"I'm breastfeeding ({lactation_status}).")

    goal_sentence = _build_mfp_goal_sentence(sample)
    if goal_sentence is None:
        closing = f"Tell me {meal_plan_days} day meal plan which meets my calorie and nutrient targets."
    else:
        parts.append(goal_sentence)
        closing = f"Tell me {meal_plan_days} day meal plan which meets my calorie and macro targets."
    parts.append(closing)

    return " ".join(parts)


def _round1(value: float) -> float:
    """Round *value* to one decimal place using Python's built-in ``round``."""
    return round(value, ndigits=1)


def _ensure_code_root_on_path() -> None:
    """Prepend the grandparent of this file's directory to ``sys.path`` if it is missing.

    This lets ``utils.*`` modules under that root be imported when the file runs as
    a script.
    """
    root = str(Path(__file__).resolve().parents[2])
    if root in sys.path:
        return
    sys.path.insert(0, root)


def _kg_to_lb(value_kg: float) -> float:
    """Convert kilograms to pounds (using 0.45359 kg per lb, matching _lb_to_kg)."""
    kg_per_lb = 0.45359
    return value_kg / kg_per_lb


def _lb_to_kg(value_lb: float) -> float:
    """Convert pounds to kilograms (using 0.45359 kg per lb, matching _kg_to_lb)."""
    kg_per_lb = 0.45359
    return value_lb * kg_per_lb


def _cm_to_ft_inches(value_cm: float) -> tuple[int, float]:
    """Convert centimetres to (whole feet, inches rounded to 0.1).

    If rounding pushes the inches up to 12.0, it is carried into the feet so the
    result never reads "5 ft 12.0 in".
    """
    inches_total = value_cm / 2.54
    whole_feet = int(inches_total // 12)
    leftover = round(inches_total - (whole_feet * 12), 1)
    if leftover < 12.0:
        return whole_feet, leftover
    return whole_feet + 1, 0.0


def _ft_inches_to_cm(feet: float, inches: float) -> float:
    """Convert a feet + inches height to centimetres (30.48 cm per foot)."""
    feet_total = feet + (inches / 12.0)
    return feet_total * 30.48


def _convert_weight_unit(sample: Dict[str, Any], target_unit: str) -> Dict[str, Any]:
    """Return a copy of *sample* with body weight (and pre-pregnancy weight) in *target_unit*.

    Args:
        sample: Row dict with "weight"/"weight_unit". It may also carry
            "prepregnancy_weight"/"prepregnancy_weight_unit". Any source unit other
            than "kg" (case-insensitive) is treated as pounds.
        target_unit: "kg" (case/whitespace-insensitive) converts to kilograms; any
            other value converts to pounds ("lb").

    Returns:
        A new dict; *sample* is not modified. Converted values are rounded to 0.1.
        A missing or None pre-pregnancy weight is left untouched.
    """
    converted = dict(sample)
    to_kg = str(target_unit).strip().lower() == "kg"
    target_label = "kg" if to_kg else "lb"

    def _rounded_in_target(value: Any, unit: Any) -> float:
        # Normalize to kilograms first, then express in the requested unit.
        value_kg = _to_float(value)
        if str(unit).lower() != "kg":
            value_kg = _lb_to_kg(value_kg)
        return _round1(value_kg if to_kg else _kg_to_lb(value_kg))

    converted["weight"] = _rounded_in_target(
        converted.get("weight", 0.0), converted.get("weight_unit", "kg")
    )
    converted["weight_unit"] = target_label

    # A None pre-pregnancy weight means "not provided" (build_profile_sentence skips
    # it). Converting it would raise TypeError in float().
    if converted.get("prepregnancy_weight") is not None:
        converted["prepregnancy_weight"] = _rounded_in_target(
            converted["prepregnancy_weight"],
            converted.get("prepregnancy_weight_unit", "kg"),
        )
        converted["prepregnancy_weight_unit"] = target_label

    return converted


def _convert_height_unit(sample: Dict[str, Any], target_unit: str) -> Dict[str, Any]:
    """Return a copy of *sample* with its height expressed in *target_unit*.

    Args:
        sample: Row dict. The "ft_in" source unit (case-insensitive) reads "height"
            as feet plus an optional "height_inches". Any other unit reads "height"
            as centimetres.
        target_unit: "cm" (case/whitespace-insensitive) stores centimetres rounded
            to 0.1 and drops "height_inches". Any other value stores whole feet in
            "height" and inches (0.1 precision) in "height_inches", with unit "ft_in".

    Returns:
        A new dict; *sample* is not modified.
    """
    converted = dict(sample)
    source_unit = str(converted.get("height_unit", "cm")).lower()

    if source_unit == "ft_in":
        feet = _to_float(converted.get("height", 0.0))
        # Whole-foot heights may have height_inches absent or None (the prompt
        # formatter renders those as "N ft"). Treat both as 0 inches instead of
        # crashing in float(None).
        raw_inches = converted.get("height_inches")
        inches = 0.0 if raw_inches is None else _to_float(raw_inches)
        height_cm = _ft_inches_to_cm(feet, inches)
    else:
        height_cm = _to_float(converted.get("height", 0.0))

    if str(target_unit).strip().lower() == "cm":
        converted["height"] = _round1(height_cm)
        converted["height_unit"] = "cm"
        converted.pop("height_inches", None)
    else:
        feet, inches = _cm_to_ft_inches(height_cm)
        converted["height"] = feet
        converted["height_unit"] = "ft_in"
        converted["height_inches"] = inches

    return converted


def _convert_age_unit(sample: Dict[str, Any], target_unit: str) -> Dict[str, Any]:
    """Return a copy of *sample* with its age expressed in *target_unit*.

    A source unit of "months" (case-insensitive) is read as months; anything else
    is read as years. Target "years" stores the age rounded to 3 decimals. Any
    other target stores months rounded to 2 decimals. *sample* is not modified.
    """
    result = dict(sample)
    unit_is_months = str(result.get("age_unit", "years")).lower() == "months"
    raw_age = float(result.get("age", 0.0))
    years = raw_age / 12.0 if unit_is_months else raw_age

    if target_unit != "years":
        result["age"] = round(years * 12.0, 2)
        result["age_unit"] = "months"
        return result

    result["age"] = round(years, 3)
    result["age_unit"] = "years"
    return result


def _normalize_age_to_years(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of *sample* whose age is expressed in years."""
    return _convert_age_unit(sample, target_unit="years")


def _apply_prompt_fields(sample: Dict[str, Any], rng: random.Random) -> None:
    """Sample the prompt-only fields into *sample* in place.

    Sets "food_type" (a cuisine) and "meal_plan_days". The cuisine is drawn first;
    keep that order so seeded datasets stay reproducible.
    """
    cuisine = rng.choice(CUISINES)
    plan_days = rng.choice(MEAL_PLAN_DAYS_OPTIONS)
    sample["food_type"] = cuisine
    sample["meal_plan_days"] = plan_days


def _to_mfp_input(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Select calculate_mfp_macros keyword arguments from *sample*.

    Keys are taken in MFP_INPUT_KEYS order; absent keys and None values are skipped.
    """
    return {key: sample[key] for key in MFP_INPUT_KEYS if sample.get(key) is not None}


def _validate_with_mfp_macro_calculator(
    calculate_mfp_macros: Any,
    sample: Dict[str, Any],
    expected_valid: bool,
    scenario_id: str,
) -> None:
    """Run the calculator on *sample* and check that its verdict matches *expected_valid*.

    A string result starting with "Error:" counts as a rejection.

    Raises:
        ValueError: the calculator's verdict disagrees with *expected_valid*.
    """
    response = calculate_mfp_macros(**_to_mfp_input(sample))
    rejected = isinstance(response, str) and response.startswith("Error:")

    if expected_valid and rejected:
        raise ValueError(f"Scenario '{scenario_id}' should be valid, but calculate_mfp_macros returned: {response}")
    if not (expected_valid or rejected):
        raise ValueError(
            f"Scenario '{scenario_id}' should be invalid, but calculate_mfp_macros returned a non-error response"
        )


def _suite_row(
    sample: Dict[str, Any],
    suite: str,
    scenario_id: str,
) -> Dict[str, Any]:
    """Return a copy of *sample* tagged with its suite name and scenario id."""
    return {**sample, "suite": suite, "scenario_id": scenario_id}


def _mfp_input_fingerprint(sample: Dict[str, Any]) -> str:
    """Return a canonical (key-sorted) JSON string of the calculator inputs, used for de-duplication."""
    calculator_kwargs = _to_mfp_input(sample)
    return json.dumps(calculator_kwargs, sort_keys=True)


def _prepare_suite_row_sample(sample: Dict[str, Any], rng: random.Random) -> Dict[str, Any]:
    """Return a copy of *sample* with sampled prompt fields and its rendered "profile_sentence"."""
    prepared = dict(sample)
    _apply_prompt_fields(prepared, rng)
    sentence = build_profile_sentence(prepared)
    prepared["profile_sentence"] = sentence
    return prepared


def _try_add_validated_suite_row(
    rows: List[Dict[str, Any]],
    seen_fingerprints: set[str],
    sample: Dict[str, Any],
    *,
    rng: random.Random,
    suite: str,
    scenario_id: str,
    fingerprint: str,
    validate_sample: Callable[[Dict[str, Any], str], None],
    on_added: Callable[[Dict[str, Any]], None] | None = None,
) -> bool:
    """Prepare, validate, and append *sample* unless its fingerprint was already seen.

    On success the row is tagged via _suite_row, the fingerprint is recorded, and
    *on_added* (if given) receives the prepared, untagged row. Exceptions raised by
    *validate_sample* propagate, and then nothing is recorded.

    Returns:
        True if a row was appended, False for a duplicate fingerprint.
    """
    is_duplicate = fingerprint in seen_fingerprints
    if is_duplicate:
        return False

    prepared = _prepare_suite_row_sample(sample, rng)
    validate_sample(prepared, scenario_id)
    tagged = _suite_row(prepared, suite=suite, scenario_id=scenario_id)
    rows.append(tagged)
    seen_fingerprints.add(fingerprint)

    if on_added is not None:
        on_added(prepared)
    return True


def _base_input(**overrides: Any) -> Dict[str, Any]:
    """Return a fresh BASE_SAMPLE copy in metric units, with no pregnancy or lactation.

    *overrides* are applied last, so they win over every default.
    """
    return {
        **BASE_SAMPLE,
        "age_unit": "years",
        "weight_unit": "kg",
        "height_unit": "cm",
        "is_pregnant": False,
        "lactation_status": "none",
        **overrides,
    }


def _scenario(scenario_id: str, sample: Dict[str, Any]) -> Dict[str, Any]:
    """Wrap *sample* (not copied) together with its scenario id."""
    return dict(scenario_id=scenario_id, sample=sample)


# MFP macro sample generation and validation.
def _mfp_default_body_metrics(sex: str, age: float) -> tuple[float, float]:
    """Return a typical (weight_kg, height_cm) pair for the sex and age bracket.

    The brackets are: under 19, under 30, under 50, and 50+. Any sex other than
    "male" (after trimming and lower-casing) gets the female values.
    """
    is_male = sex.strip().lower() == "male"
    years = float(age)

    # (exclusive upper age bound, male metrics, female metrics)
    brackets = (
        (19.0, (68.0, 175.0), (58.0, 163.0)),
        (30.0, (78.0, 179.0), (64.0, 166.0)),
        (50.0, (84.0, 177.0), (69.0, 165.0)),
    )
    for upper_bound, male_metrics, female_metrics in brackets:
        if years < upper_bound:
            return male_metrics if is_male else female_metrics
    return (80.0, 174.0) if is_male else (67.0, 162.0)


def _mfp_weekly_rates_for_goal(goal: str) -> List[float]:
    """Return the weekly rates (lb/week) to seed for *goal*: [0.0] for "maintain", else [0.5, 1.0, 2.0].

    A new list is returned on every call.
    """
    is_maintain = goal.strip().lower() == "maintain"
    return [0.0] if is_maintain else [0.5, 1.0, 2.0]


def _build_mfp_base_sample(
    sex: str,
    age: float,
    weight_kg: float,
    height_cm: float,
    activity_level: str,
    goal: str,
    weekly_rate_lbs: float,
    carb_fat_preference: str,
) -> Dict[str, Any]:
    """Assemble a metric-unit MFP calculator sample with the canonical key order.

    Age (years), weight (kg), and height (cm) are rounded to 0.1. All other
    arguments are stored unchanged.
    """
    return dict(
        sex=sex,
        age=round(float(age), 1),
        age_unit="years",
        weight=round(weight_kg, 1),
        weight_unit="kg",
        height=round(height_cm, 1),
        height_unit="cm",
        activity_level=activity_level,
        goal=goal,
        weekly_rate_lbs=weekly_rate_lbs,
        carb_fat_preference=carb_fat_preference,
    )


def _generate_random_mfp_sample(rng: random.Random, sex: str, goal: str) -> Dict[str, Any]:
    """Draw one random MFP calculator sample for the given (sex, goal) stratum.

    How each field is drawn:

    * Age: with probability 0.35, one of MFP_AGE_BOUNDARY_POINTS; otherwise uniform
      in [15, 80] years, rounded to 0.1.
    * Weight and height: independently, with probability 0.25 each, picked from a
      small fixed list. Otherwise drawn from a normal distribution around the
      sex/age defaults, clamped to 45-140 kg / 145-200 cm and rounded to 0.1.
    * Weekly rate: goals other than the exact string "maintain" get a rate from a
      fixed menu; "maintain" keeps 0.0.
    * Display units: the weight is converted to lb and the height to ft/in, each
      with probability 0.5.

    NOTE: the sequence of rng calls below determines every generated sample for a
    given seed. Reordering or adding draws changes all downstream rows.
    """
    if rng.random() < 0.35:
        age = rng.choice(MFP_AGE_BOUNDARY_POINTS)
    else:
        age = round(rng.uniform(15.0, 80.0), 1)

    default_weight_kg, default_height_cm = _mfp_default_body_metrics(sex, age)

    if rng.random() < 0.25:
        weight_kg = rng.choice([50.0, 70.0, 100.0, 120.0])
    else:
        weight_kg = round(max(45.0, min(140.0, rng.normalvariate(default_weight_kg, 12.0))), 1)

    if rng.random() < 0.25:
        height_cm = rng.choice([150.0, 160.0, 170.0, 180.0, 190.0])
    else:
        height_cm = round(max(145.0, min(200.0, rng.normalvariate(default_height_cm, 8.0))), 1)

    weekly_rate_lbs = 0.0
    # Case-sensitive comparison: callers pass values from MFP_GOAL_OPTIONS.
    if goal != "maintain":
        if rng.random() < 0.40:
            weekly_rate_lbs = rng.choice([0.5, 1.0, 2.0])
        else:
            # round() is a no-op on these literals; kept so the draw sequence is unchanged.
            weekly_rate_lbs = round(rng.choice([0.25, 0.5, 0.75, 1.0, 1.5, 2.0]), 2)

    # Keyword arguments are evaluated left to right, so the activity level is drawn
    # before the macro preference.
    sample = _build_mfp_base_sample(
        sex=sex,
        age=age,
        weight_kg=weight_kg,
        height_cm=height_cm,
        activity_level=rng.choice(MFP_ACTIVITY_LEVEL_OPTIONS),
        goal=goal,
        weekly_rate_lbs=weekly_rate_lbs,
        carb_fat_preference=rng.choice(MFP_CARB_FAT_PREFERENCE_OPTIONS),
    )

    # Randomize the display units so the calculator is exercised with lb and ft/in inputs too.
    if rng.random() < 0.50:
        sample = _convert_weight_unit(sample, "lb")
    if rng.random() < 0.50:
        sample = _convert_height_unit(sample, "ft_in")

    return sample


def _build_mfp_macro_scenarios(
    rng: random.Random,
    calculate_mfp_macros: Any,
    target_count: int,
) -> List[Dict[str, Any]]:
    """Build *target_count* validated, de-duplicated MFP macro rows.

    Rows are produced in two phases:

    1. Deterministic seed rows ("mfp_macro/seed_NNNN"), in three sets:
       * every sex x boundary age x activity level at "maintain";
       * every sex x goal x weekly rate x macro preference at age 40;
       * four fixed weight/height/unit profiles at ages 15 and 65.
       Display units alternate between kg/lb and cm/ft_in by loop index.
    2. Random rows ("mfp_macro/random_NNNN") from _generate_random_mfp_sample. These
       steer toward an even split of *target_count* over the six (sex, goal) strata;
       seed rows count toward their stratum.

    Each row is checked with calculate_mfp_macros. A result that is an "Error:"
    string raises ValueError; the sample is not skipped. Samples whose calculator
    inputs duplicate an earlier row are skipped.

    Args:
        rng: Random source. Besides the random phase, it is also consumed when
            sampling the prompt fields of every added row, seeds included.
        calculate_mfp_macros: Callable that accepts the MFP_INPUT_KEYS as keyword
            arguments.
        target_count: Number of rows to return.

    Returns:
        Exactly *target_count* rows.

    Raises:
        ValueError: A generated sample is rejected by the calculator.
        RuntimeError: Fewer than *target_count* unique rows after max_attempts
            random draws.
    """
    rows: List[Dict[str, Any]] = []
    seen_fingerprints: set[str] = set()
    # Rows accepted so far per (sex, goal) stratum; updated for seed and random rows alike.
    stratum_counts: Dict[tuple[str, str], int] = {
        (sex, goal): 0 for sex in ("male", "female") for goal in MFP_GOAL_OPTIONS
    }

    def validate_sample(row: Dict[str, Any], scenario_id: str) -> None:
        # Every row in this suite must be accepted by the calculator.
        _validate_with_mfp_macro_calculator(
            calculate_mfp_macros=calculate_mfp_macros,
            sample=row,
            expected_valid=True,
            scenario_id=scenario_id,
        )

    def on_added(row: Dict[str, Any]) -> None:
        stratum_key = (str(row["sex"]).strip().lower(), str(row["goal"]).strip().lower())
        stratum_counts[stratum_key] = stratum_counts.get(stratum_key, 0) + 1

    def add_sample(sample: Dict[str, Any], scenario_id: str) -> bool:
        # Refuse new rows once the target is reached, so the seed loops below can
        # run to completion without overshooting.
        if len(rows) >= target_count:
            return False

        fingerprint = _mfp_input_fingerprint(sample)
        return _try_add_validated_suite_row(
            rows,
            seen_fingerprints,
            sample,
            rng=rng,
            suite="mfp_macro",
            scenario_id=scenario_id,
            fingerprint=fingerprint,
            validate_sample=validate_sample,
            on_added=on_added,
        )

    # Seed ids keep counting even when a duplicate is skipped, so they may have gaps.
    seed_index = 1

    # Seed set 1: each sex x boundary age x activity level, at "maintain" with a
    # "balanced" split and default body metrics. Display units alternate by activity index.
    for sex in ("male", "female"):
        for age in MFP_AGE_BOUNDARY_POINTS:
            base_weight_kg, base_height_cm = _mfp_default_body_metrics(sex, age)
            for activity_index, activity_level in enumerate(MFP_ACTIVITY_LEVEL_OPTIONS):
                sample = _build_mfp_base_sample(
                    sex=sex,
                    age=age,
                    weight_kg=base_weight_kg,
                    height_cm=base_height_cm,
                    activity_level=activity_level,
                    goal="maintain",
                    weekly_rate_lbs=0.0,
                    carb_fat_preference="balanced",
                )
                if activity_index % 2 == 1:
                    sample = _convert_weight_unit(sample, "lb")
                if activity_index % 3 == 1:
                    sample = _convert_height_unit(sample, "ft_in")
                add_sample(sample, f"mfp_macro/seed_{seed_index:04d}")
                seed_index += 1

    # Seed set 2: age-40 defaults for each sex x goal x weekly rate x macro preference.
    # Activity level and display units rotate with the preference index.
    for sex in ("male", "female"):
        base_weight_kg, base_height_cm = _mfp_default_body_metrics(sex, 40.0)
        for goal in MFP_GOAL_OPTIONS:
            for weekly_rate_lbs in _mfp_weekly_rates_for_goal(goal):
                for preference_index, carb_fat_preference in enumerate(MFP_CARB_FAT_PREFERENCE_OPTIONS):
                    sample = _build_mfp_base_sample(
                        sex=sex,
                        age=40.0,
                        weight_kg=base_weight_kg,
                        height_cm=base_height_cm,
                        activity_level=MFP_ACTIVITY_LEVEL_OPTIONS[preference_index % len(MFP_ACTIVITY_LEVEL_OPTIONS)],
                        goal=goal,
                        weekly_rate_lbs=weekly_rate_lbs,
                        carb_fat_preference=carb_fat_preference,
                    )
                    if preference_index % 2 == 1:
                        sample = _convert_weight_unit(sample, "lb")
                    if preference_index % 2 == 0:
                        sample = _convert_height_unit(sample, "ft_in")
                    add_sample(sample, f"mfp_macro/seed_{seed_index:04d}")
                    seed_index += 1

    # Seed set 3: fixed weight/height profiles, one per display-unit combination, at
    # ages 15 and 65. Each profile index also selects the activity level, the goal
    # (cycling through MFP_GOAL_OPTIONS), that goal's first seed rate, and the preference.
    unit_profiles = [
        {"weight_kg": 50.0, "height_cm": 150.0, "weight_unit": "kg", "height_unit": "cm"},
        {"weight_kg": 70.0, "height_cm": 165.0, "weight_unit": "lb", "height_unit": "cm"},
        {"weight_kg": 100.0, "height_cm": 180.0, "weight_unit": "kg", "height_unit": "ft_in"},
        {"weight_kg": 120.0, "height_cm": 190.0, "weight_unit": "lb", "height_unit": "ft_in"},
    ]
    for sex in ("male", "female"):
        for age in (15.0, 65.0):
            for profile_index, profile in enumerate(unit_profiles):
                sample = _build_mfp_base_sample(
                    sex=sex,
                    age=age,
                    weight_kg=profile["weight_kg"],
                    height_cm=profile["height_cm"],
                    activity_level=MFP_ACTIVITY_LEVEL_OPTIONS[profile_index],
                    goal=MFP_GOAL_OPTIONS[profile_index % len(MFP_GOAL_OPTIONS)],
                    weekly_rate_lbs=_mfp_weekly_rates_for_goal(
                        MFP_GOAL_OPTIONS[profile_index % len(MFP_GOAL_OPTIONS)]
                    )[0],
                    carb_fat_preference=MFP_CARB_FAT_PREFERENCE_OPTIONS[profile_index],
                )
                if profile["weight_unit"] == "lb":
                    sample = _convert_weight_unit(sample, "lb")
                if profile["height_unit"] == "ft_in":
                    sample = _convert_height_unit(sample, "ft_in")
                add_sample(sample, f"mfp_macro/seed_{seed_index:04d}")
                seed_index += 1

    if len(rows) >= target_count:
        return rows[:target_count]

    # Random phase: split target_count as evenly as possible over the six
    # (sex, goal) strata. The first `stratum_remainder` strata get one extra row.
    strata = [(sex, goal) for sex in ("male", "female") for goal in MFP_GOAL_OPTIONS]
    target_per_stratum = target_count // len(strata)
    stratum_remainder = target_count % len(strata)
    stratum_targets: Dict[tuple[str, str], int] = {}
    for index, stratum in enumerate(strata):
        stratum_targets[stratum] = target_per_stratum + (1 if index < stratum_remainder else 0)

    attempts = 0
    # Upper bound on random draws; duplicate samples consume attempts without adding rows.
    max_attempts = max(target_count * 50, 10000)
    while len(rows) < target_count and attempts < max_attempts:
        attempts += 1
        # Rotate through the strata still below their target. The all-strata
        # fallback is defensive: since every added row is counted in a stratum,
        # meeting all targets should already end the loop.
        pending_strata = [stratum for stratum in strata if stratum_counts[stratum] < stratum_targets[stratum]]
        if pending_strata:
            sex, goal = pending_strata[attempts % len(pending_strata)]
        else:
            sex, goal = strata[attempts % len(strata)]

        sample = _generate_random_mfp_sample(rng, sex=sex, goal=goal)
        # The id uses the next row number, so a skipped duplicate reuses it on the next attempt.
        add_sample(sample, f"mfp_macro/random_{len(rows) + 1:04d}")

    if len(rows) < target_count:
        raise RuntimeError(
            f"Unable to build requested mfp_macro suite size: requested={target_count}, built={len(rows)}"
        )

    return rows


def build_datasets(
    dri_target: int,
    fuzz_valid_target: int,
    mfp_target: int,
) -> Dict[str, List[Dict[str, Any]]]:
    """Build every requested evaluation suite.

    Args:
        dri_target: Forwarded to dri_data.build_datasets.
        fuzz_valid_target: Forwarded to dri_data.build_datasets. The DRI builder
            runs only when this or *dri_target* is positive.
        mfp_target: Number of MFP macro rows; 0 yields an empty "mfp_macro" list.

    Returns:
        Mapping of suite name -> rows. "mfp_macro" is always present, possibly empty.
        DRI suites are present only when requested.
    """
    rng = random.Random(FIXED_SEED)
    _ensure_code_root_on_path()

    calculate_mfp_macros = None
    if mfp_target > 0:
        # Imported only when MFP rows are requested, so DRI-only (and legacy) runs do
        # not require the MFP calculator module. It is still imported before the DRI
        # build, so a missing module fails fast.
        from utils.calculate_mfp_macros import calculate_mfp_macros

    datasets: Dict[str, List[Dict[str, Any]]] = {}

    if dri_target > 0 or fuzz_valid_target > 0:
        datasets.update(
            build_dri_datasets(
                dri_target=dri_target,
                fuzz_valid_target=fuzz_valid_target,
            )
        )

    datasets["mfp_macro"] = []
    if mfp_target > 0:
        datasets["mfp_macro"] = _build_mfp_macro_scenarios(
            rng=rng,
            calculate_mfp_macros=calculate_mfp_macros,
            target_count=mfp_target,
        )

    return datasets


# Output and CLI entrypoints.
def write_outputs(output_dir: Path, datasets: Dict[str, List[Dict[str, Any]]]) -> None:
    """Write each dataset to ``<output_dir>/<stem>.jsonl``, one JSON object per line.

    The stem comes from SUITE_OUTPUT_FILE_STEMS and defaults to the dataset name.
    The directory is created if needed, and existing files are overwritten. A
    per-dataset and a total summary are printed.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    grand_total = 0
    for name, rows in datasets.items():
        path = output_dir / f"{SUITE_OUTPUT_FILE_STEMS.get(name, name)}.jsonl"
        with path.open("w", encoding="utf-8") as handle:
            handle.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)

        row_count = len(rows)
        grand_total += row_count
        print(f"Wrote {row_count} samples for {name}")
        print(f"JSONL: {path}")

    print(f"Total samples written: {grand_total}")


def main() -> None:
    """CLI entry point: parse arguments, build the requested suites, and write them.

    The hidden legacy flags --samples-per-scenario / --samples-per-category switch
    to DRI-only generation with that many cases. When both are given, the first
    one wins, and an explicit 0 is respected. Negative values are clamped to 0.
    """
    parser = argparse.ArgumentParser(
        description="Generate practical evaluation samples for DRI and MFP macro suites."
    )
    parser.add_argument(
        "--dri-target",
        type=int,
        default=0,
        help="Number of DRI branch+boundary valid cases (recommended 120-180).",
    )
    parser.add_argument(
        "--fuzz-valid-target",
        type=int,
        default=0,
        help="Number of seeded random valid fuzz cases (recommended 150-300).",
    )
    parser.add_argument(
        "--mfp-target",
        type=int,
        default=PRACTICAL_MFP_TARGET,
        help="Number of boundary-plus-stratified-valid MFP macro cases (recommended 250-500).",
    )
    parser.add_argument(
        "--samples-per-scenario",
        type=int,
        default=None,
        help=argparse.SUPPRESS,
    )
    parser.add_argument(
        "--samples-per-category",
        type=int,
        default=None,
        help=argparse.SUPPRESS,
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path(__file__).resolve().parent,
        # write_outputs names files <suite>.jsonl (mfp_macro is written as mfp.jsonl).
        help="Directory where one <suite>.jsonl file per generated suite is written (e.g. mfp.jsonl).",
    )
    args = parser.parse_args()

    supplied_legacy = [
        value
        for value in (args.samples_per_scenario, args.samples_per_category)
        if value is not None
    ]
    if supplied_legacy:
        # Take the first flag actually given. A chained ``or`` would silently turn
        # an explicit 0 into the next fallback.
        legacy_value = supplied_legacy[0]
        datasets = build_datasets(
            dri_target=max(legacy_value, 0),
            fuzz_valid_target=0,
            mfp_target=0,
        )
    else:
        if args.dri_target < 0 or args.fuzz_valid_target < 0 or args.mfp_target < 0:
            # Report bad input the argparse way (usage + message, exit status 2).
            parser.error("--dri-target, --fuzz-valid-target, and --mfp-target must be >= 0")
        datasets = build_datasets(
            dri_target=args.dri_target,
            fuzz_valid_target=args.fuzz_valid_target,
            mfp_target=args.mfp_target,
        )
    write_outputs(args.output_dir, datasets)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
