"""Candidate solution generation and deterministic corruptions."""

from __future__ import annotations

import copy
from typing import Any

import numpy as np
import pandas as pd

from .utils import CANDIDATE_TYPES, format_number, from_json, to_json


def generate_candidates(problems: pd.DataFrame, seed: int) -> pd.DataFrame:
    """Generate correct and corrupted candidate solutions for each problem.

    One row is emitted per (problem, candidate type) pair, walking the
    problems in frame order and the types in ``CANDIDATE_TYPES`` order.

    Args:
        problems: Problem table. ``problem_id`` and ``category`` are read
            here; each row is also forwarded to the helper functions, which
            read further columns.
        seed: Base seed. The random generator is created from ``seed + 991``.

    Returns:
        A DataFrame with the columns listed in ``columns`` below. The schema
        is fixed explicitly so that an empty ``problems`` table still yields
        a frame carrying every expected column instead of a column-less one.
    """

    columns = [
        "problem_id",
        "candidate_id",
        "category",
        "candidate_type",
        "candidate_value",
        "candidate_unit",
        "candidate_solution_text",
        "is_valid",
        "structured_steps_json",
    ]

    # A single generator is shared by every problem, so the random draws
    # depend on the row order of ``problems`` as well as on ``seed``.
    rng = np.random.default_rng(seed + 991)
    rows: list[dict[str, object]] = []
    for _, problem in problems.iterrows():
        for candidate_type in CANDIDATE_TYPES:
            value, unit, text = _candidate_value_unit_text(problem, candidate_type, rng)
            steps = _build_steps(problem, candidate_type, value, unit)
            rows.append(
                {
                    "problem_id": problem["problem_id"],
                    "candidate_id": f"{problem['problem_id']}_{candidate_type}",
                    "category": problem["category"],
                    "candidate_type": candidate_type,
                    "candidate_value": float(value),
                    "candidate_unit": unit,
                    "candidate_solution_text": text,
                    # Only the untouched reference solution counts as valid.
                    "is_valid": candidate_type == "correct",
                    "structured_steps_json": to_json(steps),
                }
            )
    # Passing ``columns`` keeps the column order identical for non-empty
    # input and guarantees the schema when ``rows`` is empty.
    return pd.DataFrame(rows, columns=columns)


def _candidate_value_unit_text(
    problem: pd.Series, candidate_type: str, rng: np.random.Generator
) -> tuple[float, str, str]:
    """Return the reported value, unit and solution text for one candidate.

    Args:
        problem: Problem row; ``ground_truth_value``, ``ground_truth_unit``,
            ``category`` and ``parameters_json`` are always read, and
            ``correct_solution_text`` is read for the ``"correct"`` type.
        candidate_type: Which variant to produce. Any type not named
            explicitly below falls through to the "plausible wrong unit"
            variant.
        rng: Random generator; consulted only for ``"wrong_arithmetic"``.

    Returns:
        ``(value, unit, text)`` for the candidate.
    """
    true_value = float(problem["ground_truth_value"])
    true_unit = str(problem["ground_truth_unit"])
    kind = str(problem["category"])
    parameters = from_json(problem["parameters_json"], {})

    if candidate_type == "correct":
        return true_value, true_unit, str(problem["correct_solution_text"])

    # Every corruption starts from the truth and perturbs either the value
    # or the unit, never both.
    reported_value = true_value
    reported_unit = true_unit

    if candidate_type == "wrong_arithmetic":
        label = "Arithmetic slip"
        scale = float(rng.choice([0.1, 0.5, 1.5, 2.0, 10.0]))
        reported_value = true_value * scale
    elif candidate_type == "wrong_formula":
        label = "Wrong formula"
        reported_value = _wrong_formula_value(kind, parameters, true_value)
    elif candidate_type == "wrong_unit":
        label = "Wrong unit"
        reported_unit = _incompatible_unit(true_unit)
    elif candidate_type == "missing_conversion":
        label = "Missing conversion"
        reported_value = _missing_conversion_value(kind, parameters, true_value)
    else:
        label = "Plausible scalar with wrong unit"
        reported_unit = _plausible_wrong_unit(true_unit)

    return reported_value, reported_unit, _solution_text(label, reported_value, reported_unit)


def _solution_text(prefix: str, value: float, unit: str) -> str:
    """Return the one-sentence solution text: prefix, formatted value, unit."""
    rendered_value = format_number(value)
    return f"{prefix}: the candidate reports {rendered_value} {unit}."


def _wrong_formula_value(category: str, params: dict[str, Any], truth: float) -> float:
    """Compute the result of applying a deliberately wrong formula.

    Args:
        category: Problem category name selecting the wrong formula.
        params: Problem parameters; only the keys of the chosen category
            are read.
        truth: Ground-truth value, used only for unrecognised categories.

    Returns:
        The wrong-formula result, or ``truth * 2`` when ``category`` has no
        dedicated formula.
    """

    def unit_conversion() -> Any:
        # Applies the opposite scaling direction for hours_to_minutes; every
        # other conversion case is multiplied by 1000.
        amount = params["source_value"]
        if params["conversion_case"] == "hours_to_minutes":
            return amount / 60
        return amount * 1000

    # Each entry is evaluated lazily so that only the selected category
    # touches ``params``.
    formulas = {
        "Dosage": lambda: params["weight_kg"] / params["dose_mg_per_kg"],
        "Dilution": lambda: (
            params["stock_concentration_mM"]
            * params["final_volume_mL"]
            / params["target_concentration_mM"]
        ),
        "Cell count": lambda: params["density_cells_per_mL"] / params["volume_mL"],
        "Half-life decay": lambda: params["initial_concentration_mg_per_L"]
        * (0.5 ** (params["half_life_hr"] / params["elapsed_hr"])),
        "Exponential growth": lambda: params["initial_cells"]
        * (2 ** (params["doubling_time_hr"] / params["elapsed_hr"])),
        "Flow rate": lambda: params["flow_mL_per_hr"] / params["elapsed_hr"],
        "Imaging scale area": lambda: (params["width_px"] + params["height_px"])
        * params["pixel_size_mm_per_pixel"],
        "Bioink weight/volume percent": lambda: params["percent_wv"] / params["volume_mL"],
        "Molarity": lambda: params["concentration_mM"] / params["volume_mL"],
        "Unit conversion": unit_conversion,
    }

    compute = formulas.get(category)
    if compute is None:
        return truth * 2
    return compute()


def _missing_conversion_value(category: str, params: dict[str, Any], truth: float) -> float:
    """Compute the result of a calculation with a unit-scaling mistake.

    Args:
        category: Problem category name selecting the mis-scaled formula.
        params: Problem parameters; only the keys of the chosen category
            are read.
        truth: Ground-truth value, used only for unrecognised categories.

    Returns:
        The mis-scaled result, or ``truth * 10`` when ``category`` has no
        dedicated formula.

    NOTE(review): the "Exponential growth" branch raises ``2`` to
    ``elapsed_hr * 60 / doubling_time_hr``; Python raises ``OverflowError``
    once that float power exceeds the double range. Confirm the parameter
    ranges used upstream keep the exponent small enough.
    """
    # Each entry is evaluated lazily so that only the selected category
    # touches ``params``.
    formulas = {
        "Dosage": lambda: params["weight_kg"] * 1000 * params["dose_mg_per_kg"],
        "Dilution": lambda: (
            params["target_concentration_mM"]
            * (params["final_volume_mL"] / 1000)
            / params["stock_concentration_mM"]
        ),
        "Cell count": lambda: params["density_cells_per_mL"] * (params["volume_mL"] / 1000),
        "Half-life decay": lambda: params["initial_concentration_mg_per_L"]
        * (0.5 ** ((params["elapsed_hr"] * 60) / params["half_life_hr"])),
        "Exponential growth": lambda: params["initial_cells"]
        * (2 ** ((params["elapsed_hr"] * 60) / params["doubling_time_hr"])),
        "Flow rate": lambda: params["flow_mL_per_hr"] * (params["elapsed_hr"] * 60),
        "Imaging scale area": lambda: params["width_px"]
        * params["height_px"]
        * params["pixel_size_mm_per_pixel"],
        "Bioink weight/volume percent": lambda: params["percent_wv"] * params["volume_mL"],
        "Molarity": lambda: params["concentration_mM"] * params["volume_mL"],
        # The source value is reported as-is, i.e. not converted at all.
        "Unit conversion": lambda: params["source_value"],
    }

    compute = formulas.get(category)
    if compute is None:
        return truth * 10
    return compute()


def _incompatible_unit(truth_unit: str) -> str:
    """Return the replacement unit paired with ``truth_unit`` in the table below.

    Units that are not listed map to ``"mL"``.
    """
    replacements = {
        "mg": "mL",
        "mL": "mg",
        "cells": "mg",
        "mg/L": "cells",
        "mm^2": "mL",
        "g": "mL",
        "mmol": "mg",
        "L": "mg",
        "min": "mg",
    }
    if truth_unit in replacements:
        return replacements[truth_unit]
    return "mL"


def _plausible_wrong_unit(truth_unit: str) -> str:
    """Return the alternative unit paired with ``truth_unit`` in the table below.

    Units that are not listed map to ``"g"``.
    """
    alternatives = {
        "mg": "g",
        "mL": "L",
        "cells": "kcell",
        "mg/L": "g/L",
        "mm^2": "cm^2",
        "g": "mg",
        "mmol": "mol",
        "L": "mL",
        "min": "hr",
    }
    if truth_unit in alternatives:
        return alternatives[truth_unit]
    return "g"


def _build_steps(problem: pd.Series, candidate_type: str, value: float, unit: str) -> list[dict[str, object]]:
    """Return the structured steps that accompany one candidate.

    ``"wrong_formula"`` candidates get a dedicated step list; every other
    type reuses the reference steps with only the final result replaced by
    ``value`` and ``unit``.
    """
    if candidate_type != "wrong_formula":
        reference_steps = _correct_steps(problem)
        return _override_final_result(reference_steps, value, unit)

    return _wrong_formula_steps(problem, value, unit)


def _override_final_result(steps: list[dict[str, object]], value: float, unit: str) -> list[dict[str, object]]:
    """Return a deep copy of ``steps`` whose last step reports ``value`` and ``unit``.

    The input list and its step dicts are left untouched. An empty input
    yields an empty copy.
    """
    duplicate = copy.deepcopy(steps)
    if duplicate:
        # Only the final step's result fields are replaced; operands stay as they were.
        duplicate[-1].update(result_value=float(value), result_unit=unit)
    return duplicate


def _correct_steps(problem: pd.Series) -> list[dict[str, object]]:
    """Build the structured step list for a problem's reference solution.

    Reads ``category``, ``ground_truth_value``, ``ground_truth_unit`` and
    ``parameters_json`` from ``problem``. Intermediate results are recomputed
    from the parsed parameters, while the last step of every non-empty list
    reports the stored ground-truth value and unit.

    Returns:
        Step dicts in evaluation order, or an empty list when the category
        is not one of those handled here.
    """
    kind = str(problem["category"])
    answer = float(problem["ground_truth_value"])
    answer_unit = str(problem["ground_truth_unit"])
    p = from_json(problem["parameters_json"], {})
    pure = "dimensionless"

    def scaled_by_power(base, period_key, start_key, start_unit):
        # Shared shape of decay and growth:
        # elapsed / period -> base ** ratio -> start quantity * factor.
        elapsed, period = p["elapsed_hr"], p[period_key]
        ratio = elapsed / period
        factor = base**ratio
        return [
            _binary("divide", elapsed, "hr", period, "hr", ratio, pure),
            _power(base, pure, ratio, pure, factor, pure),
            _binary("multiply", p[start_key], start_unit, factor, pure, answer, answer_unit),
        ]

    if kind == "Dosage":
        weight, dose = p["weight_kg"], p["dose_mg_per_kg"]
        return [_binary("multiply", weight, "kg", dose, "mg/kg", answer, answer_unit)]

    if kind == "Dilution":
        target, stock = p["target_concentration_mM"], p["stock_concentration_mM"]
        fraction = target / stock
        return [
            _binary("divide", target, "mM", stock, "mM", fraction, pure),
            _binary("multiply", fraction, pure, p["final_volume_mL"], "mL", answer, answer_unit),
        ]

    if kind == "Cell count":
        density, volume = p["density_cells_per_mL"], p["volume_mL"]
        return [_binary("multiply", density, "cells/mL", volume, "mL", answer, answer_unit)]

    if kind == "Half-life decay":
        return scaled_by_power(0.5, "half_life_hr", "initial_concentration_mg_per_L", "mg/L")

    if kind == "Exponential growth":
        return scaled_by_power(2, "doubling_time_hr", "initial_cells", "cells")

    if kind == "Flow rate":
        rate, hours = p["flow_mL_per_hr"], p["elapsed_hr"]
        return [_binary("multiply", rate, "mL/hr", hours, "hr", answer, answer_unit)]

    if kind == "Imaging scale area":
        # Convert each pixel dimension to millimetres, then multiply the sides.
        width_px = p["width_px"]
        scale = p["pixel_size_mm_per_pixel"]
        width_mm = width_px * scale
        height_px = p["height_px"]
        height_mm = height_px * scale
        return [
            _binary("multiply", width_px, "pixel", scale, "mm/pixel", width_mm, "mm"),
            _binary("multiply", height_px, "pixel", scale, "mm/pixel", height_mm, "mm"),
            _binary("multiply", width_mm, "mm", height_mm, "mm", answer, answer_unit),
        ]

    if kind == "Bioink weight/volume percent":
        percent = p["percent_wv"]
        grams_per_ml = percent / 100
        return [
            _binary("divide", percent, "g", 100, "mL", grams_per_ml, "g/mL"),
            _binary("multiply", grams_per_ml, "g/mL", p["volume_mL"], "mL", answer, answer_unit),
        ]

    if kind == "Molarity":
        molar, volume = p["concentration_mM"], p["volume_mL"]
        return [_binary("multiply", molar, "mM", volume, "mL", answer, answer_unit)]

    if kind == "Unit conversion":
        conversion = {
            "op": "convert",
            "input_value": float(p["source_value"]),
            "input_unit": p["source_unit"],
            "result_value": answer,
            "result_unit": answer_unit,
        }
        return [conversion]

    return []


def _wrong_formula_steps(problem: pd.Series, value: float, unit: str) -> list[dict[str, object]]:
    """Build the structured step list for a wrong-formula candidate.

    Reads ``category`` and ``parameters_json`` from ``problem``. Intermediate
    results are recomputed from the parsed parameters using the wrong
    formula, and the last step reports the supplied ``value`` and ``unit``.

    Returns:
        Step dicts in evaluation order. For a category without a dedicated
        branch, the reference steps are returned with their final result
        replaced by ``value`` and ``unit``.
    """
    kind = str(problem["category"])
    p = from_json(problem["parameters_json"], {})
    pure = "dimensionless"

    def inverted_power(base, period_key, start_key, start_unit):
        # Decay and growth share one shape here, with the exponent taken as
        # period / elapsed.
        period, elapsed = p[period_key], p["elapsed_hr"]
        ratio = period / elapsed
        factor = base**ratio
        return [
            _binary("divide", period, "hr", elapsed, "hr", ratio, pure),
            _power(base, pure, ratio, pure, factor, pure),
            _binary("multiply", p[start_key], start_unit, factor, pure, value, unit),
        ]

    if kind == "Dosage":
        weight, dose = p["weight_kg"], p["dose_mg_per_kg"]
        return [_binary("divide", weight, "kg", dose, "mg/kg", value, unit)]

    if kind == "Dilution":
        stock, target = p["stock_concentration_mM"], p["target_concentration_mM"]
        fraction = stock / target
        return [
            _binary("divide", stock, "mM", target, "mM", fraction, pure),
            _binary("multiply", fraction, pure, p["final_volume_mL"], "mL", value, unit),
        ]

    if kind == "Cell count":
        density, volume = p["density_cells_per_mL"], p["volume_mL"]
        return [_binary("divide", density, "cells/mL", volume, "mL", value, unit)]

    if kind == "Half-life decay":
        return inverted_power(0.5, "half_life_hr", "initial_concentration_mg_per_L", "mg/L")

    if kind == "Exponential growth":
        return inverted_power(2, "doubling_time_hr", "initial_cells", "cells")

    if kind == "Flow rate":
        rate, hours = p["flow_mL_per_hr"], p["elapsed_hr"]
        return [_binary("divide", rate, "mL/hr", hours, "hr", value, unit)]

    if kind == "Imaging scale area":
        # Pixel dimensions are added, then scaled once by the pixel size.
        width_px, height_px = p["width_px"], p["height_px"]
        total_px = width_px + height_px
        return [
            _binary("add", width_px, "pixel", height_px, "pixel", total_px, "pixel"),
            _binary("multiply", total_px, "pixel", p["pixel_size_mm_per_pixel"], "mm/pixel", value, unit),
        ]

    if kind == "Bioink weight/volume percent":
        percent, volume = p["percent_wv"], p["volume_mL"]
        return [_binary("divide", percent, "g", volume, "mL", value, unit)]

    if kind == "Molarity":
        molar, volume = p["concentration_mM"], p["volume_mL"]
        return [_binary("divide", molar, "mM", volume, "mL", value, unit)]

    if kind == "Unit conversion":
        conversion = {
            "op": "convert",
            "input_value": float(p["source_value"]),
            "input_unit": p["source_unit"],
            "result_value": float(value),
            "result_unit": unit,
        }
        return [conversion]

    reference_steps = _correct_steps(problem)
    return _override_final_result(reference_steps, value, unit)


def _binary(
    op: str,
    left_value: float,
    left_unit: str,
    right_value: float,
    right_unit: str,
    result_value: float,
    result_unit: str,
) -> dict[str, object]:
    """Return a step dict describing a two-operand operation.

    The three numeric fields are coerced with ``float``; ``op`` and the unit
    strings are stored unchanged.
    """
    return dict(
        op=op,
        left_value=float(left_value),
        left_unit=left_unit,
        right_value=float(right_value),
        right_unit=right_unit,
        result_value=float(result_value),
        result_unit=result_unit,
    )


def _power(
    base_value: float,
    base_unit: str,
    exponent_value: float,
    exponent_unit: str,
    result_value: float,
    result_unit: str,
) -> dict[str, object]:
    """Return a step dict describing an exponentiation with op ``"power"``.

    The three numeric fields are coerced with ``float``; the unit strings
    are stored unchanged.
    """
    return dict(
        op="power",
        base_value=float(base_value),
        base_unit=base_unit,
        exponent_value=float(exponent_value),
        exponent_unit=exponent_unit,
        result_value=float(result_value),
        result_unit=result_unit,
    )
