import importlib
import sys
import os

# Put the directory that holds this file on sys.path (once), presumably so the
# puzzle modules imported by name further down can be found there.
current_dir = os.path.split(os.path.abspath(__file__))[0]
if sys.path.count(current_dir) == 0:
    sys.path.append(current_dir)

# Puzzle module names that can be recognised inside a longer data_source slug.
# Order matters: the first entry found in a slug wins.
KNOWN_PUZZLES = [
    "bridges",
    "galaxies",
    "pattern",
    "undead",
]


def _resolve_module(data_source):
    """Map a data_source slug to the name of the puzzle module to import.

    A value that is exactly one of ``KNOWN_PUZZLES`` is returned unchanged.
    Otherwise the first known puzzle name (in list order) that occurs inside
    the slug is returned, e.g.
    'anon-neurips26_bridges_5x5de_test200_intformat_json' -> 'bridges'.
    When nothing matches, the slug itself is returned.
    """
    if data_source in KNOWN_PUZZLES:
        return data_source
    # Lazy scan: stop at the first known puzzle name embedded in the slug,
    # falling back to the slug itself.
    return next((name for name in KNOWN_PUZZLES if name in data_source), data_source)


def _call_puzzle_scorer(module, solution_str, ground_truth, extra_info, method):
    """Call ``module.compute_score``, retrying without ``extra_info`` on TypeError.

    Some puzzle modules accept an ``extra_info`` keyword and some do not, so
    the call is first attempted with it and repeated without it if the first
    attempt raises ``TypeError``.
    """
    try:
        return module.compute_score(solution_str, ground_truth, extra_info=extra_info, method=method)
    except TypeError:
        # Fallback for reward functions that don't accept extra_info
        return module.compute_score(solution_str, ground_truth, method=method)


def _parse_exact_format_weight(method, default=0.05):
    """Extract the format weight W from an ``exact_plus_format_W`` method name.

    Returns ``default`` when the name has no fourth ``_``-separated field or
    when that field is not a valid float.
    """
    parts = method.split("_")
    if len(parts) == 4:  # exact_plus_format_0.2
        try:
            return float(parts[3])
        except ValueError:
            # Deliberate best-effort: an unparsable suffix keeps the default.
            pass
    return default


def _parse_exact_only_puzzles(raw):
    """Normalise the ``exact_only_puzzles`` option to a list of module names.

    Accepts a comma-separated string (e.g. "galaxies,pattern"), any iterable
    of names, or ``None`` (treated as empty).
    """
    if raw is None:
        return []
    if isinstance(raw, str):
        return [name.strip() for name in raw.split(",") if name.strip()]
    return list(raw)


def compute_score(data_source, solution_str, ground_truth, extra_info=None, method="strict", **kwargs):
    """Score ``solution_str`` with the puzzle module resolved from ``data_source``.

    Args:
        data_source: Puzzle name or dataset slug; resolved to a module name
            by ``_resolve_module``.
        solution_str: Model output to score.
        ground_truth: Reference answer forwarded to the puzzle module.
        extra_info: Optional mapping forwarded to puzzle modules that accept
            it. It is never modified; the ``partial_v2_plus_format`` branch
            works on a shallow copy.
        method: Scoring method.
            * ``"exact_plus_format"`` / ``"exact_plus_format_W"``: exact
              score plus ``W`` (default 0.05) times the XML-format reward.
            * ``"partial_v2_plus_format"``: exact + partial_v2 + weighted
              format reward (see kwargs).
            * anything else is passed straight to the puzzle module.
        **kwargs: Options read by ``partial_v2_plus_format``:
            ``format_weight`` (default 0.2), ``exact_only_puzzles``, and
            ``power_exponent`` / ``changed_cell_weight`` / ``reward_mode``,
            which are copied into ``extra_info`` when not already present.

    Returns:
        For the two combined methods, a dict with ``score``, ``acc`` and
        ``format_xmlcount`` (plus ``partial_score`` for
        ``partial_v2_plus_format``). Otherwise whatever the puzzle module's
        ``compute_score`` returns.
        NOTE(review): the combined methods add the module's results to
        floats, so they assume the module returns a number for "exact" and
        "partial_v2" -- confirm against the puzzle modules.

    Raises:
        NotImplementedError: In the pass-through path, when importing or
            calling the puzzle module raises ``ModuleNotFoundError``.
        ModuleNotFoundError: In the combined-method paths, when the puzzle
            module or the format-reward module cannot be imported.
    """
    module_name = _resolve_module(data_source)

    # Combined exact + format reward: works with any puzzle's exact match
    # method format: "exact_plus_format" or "exact_plus_format_W" where W is format weight
    if method.startswith("exact_plus_format"):
        format_weight = _parse_exact_format_weight(method)

        from reward_function.advanced.format_rewards import xmlcount_reward_func
        module = importlib.import_module(module_name)
        acc = _call_puzzle_scorer(module, solution_str, ground_truth, extra_info, "exact")
        fmt = xmlcount_reward_func(solution_str=solution_str)
        score = acc + format_weight * fmt
        return {"score": score, "acc": acc, "format_xmlcount": fmt}

    # Combined partial_v2 + exact match + format reward
    if method == "partial_v2_plus_format":
        format_weight = kwargs.get("format_weight", 0.2)

        # Inject reward kwargs into extra_info (puzzle modules read from extra_info).
        # Work on a shallow copy: writing into the caller's dict would leak the
        # injected keys into later calls that reuse it, and the "not in" guard
        # below would then make those stale values win over new kwargs.
        extra_info = dict(extra_info) if extra_info else {}
        for key in ("power_exponent", "changed_cell_weight", "reward_mode"):
            if key not in extra_info and key in kwargs:
                extra_info[key] = kwargs[key]

        # Per-puzzle exact-only override: skip partial reward for puzzles whose
        # partial signal is too gameable (e.g. galaxies — only ~6 fillable cells
        # per 7x7 grid with 50/50 binary targets, so a "guess one char everywhere"
        # strategy floors partial at ~0.20 and decouples from actual solving).
        # Comma-separated puzzle module names: e.g. "galaxies" or "galaxies,pattern".
        exact_only_puzzles = _parse_exact_only_puzzles(kwargs.get("exact_only_puzzles", ""))

        from reward_function.advanced.format_rewards import xmlcount_reward_func
        module = importlib.import_module(module_name)

        if module_name in exact_only_puzzles:
            acc = _call_puzzle_scorer(module, solution_str, ground_truth, extra_info, "exact")
            fmt = xmlcount_reward_func(solution_str=solution_str)
            score = acc + format_weight * fmt
            return {"score": score, "acc": acc, "format_xmlcount": fmt, "partial_score": 0.0}

        partial_score = _call_puzzle_scorer(module, solution_str, ground_truth, extra_info, "partial_v2")
        acc = _call_puzzle_scorer(module, solution_str, ground_truth, extra_info, "exact")

        fmt = xmlcount_reward_func(solution_str=solution_str)
        score = acc + partial_score + format_weight * fmt
        return {"score": score, "acc": acc, "format_xmlcount": fmt, "partial_score": partial_score}

    # Pass-through: hand the method to the puzzle module unchanged.
    try:
        module = importlib.import_module(module_name)
        # Pass extra_info to reward functions that support it (e.g., bridges)
        return _call_puzzle_scorer(module, solution_str, ground_truth, extra_info, method)
    except ModuleNotFoundError as err:
        # Chain the original error so a missing transitive dependency is not
        # mistaken for a missing puzzle module.
        raise NotImplementedError(f"No scoring implementation for data_source: {data_source}") from err

if __name__ == "__main__":
    # Smoke check 1: a plain puzzle name is routed straight to its module.
    routed = compute_score(
        "cryptarithm",
        "#### 8425+1923=10347",
        "8425+1923=10348",
        {"question": "tank+wine=wheat"},
    )
    print(f"cryptarithm: {routed}")

    # Smoke check 2: a long dataset slug is resolved to the puzzle it names,
    # using the combined exact + format method with an explicit weight.
    slug = "anon-neurips26_bridges_5x5de_test200_intformat_json"
    resolved = compute_score(slug, "<answer>1,2</answer>", "1,2", method="exact_plus_format_0.2")
    print(f"bridges slug (exact_plus_format_0.2): {resolved}")
