"""
Streamlined reward system for chemistry tasks.

This module provides reward functions for:
- Count tasks (single and multi)
- Index identification tasks (single and multi)
- Constraint generation tasks
- Other task types, via an auto-detected comparison against a supplied target

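Passing ``return_details=True`` to ``chemical_reward`` makes it return a
dictionary instead of a bare float for every task type: count, index and
auto-detected tasks get a compact summary (reward, comparison kind, target,
prediction, match flag), while constraint-generation tasks return the detail
payload produced by the constraint comparator. The legacy count helpers can
also be called directly for detailed per-property feedback.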
"""

from typing import Any, Dict, Union

# Import legacy reward helpers (still available for detailed diagnostics)
from .count_reward import (
    multi_count_dict_reward,
    single_count_reward
)

from .index_reward import (
    multi_index_identification_reward,
    single_index_reward
)

from .constraint_reward import (
    multi_constraint_generation_reward,
    constraint_reward
)

# New comparator-based reward utilities
from .reward_comparators import (
    compare_answers,
    compare_index_sets,
    compare_numeric_values,
    compare_smiles,
    compare_smiles_with_constraints,
    detect_answer_type,
    extract_numeric_values,
    extract_smiles_for_reward,
    extract_values_from_dict
)

# Import general utilities
from .utils import (
    valid_smiles,
    is_reasonable_molecule,
    evaluate_numeric_constraint,
    parse_natural_language_property
)


def _basic_detail(
    *,
    reward: float,
    comparison: str,
    target: Any,
    predicted: Any,
) -> Dict[str, Any]:
    """Lightweight detail payload used when ``return_details`` is requested."""
    return {
        "reward": reward,
        "details": {
            "comparison": comparison,
            "target": target,
            "predicted": predicted,
            "match": bool(reward),
        },
        "matched": int(bool(reward)),
        "total": 1,
        "extra_predictions": {}
    }

def chemical_reward(
    task_type: str,
    predicted,
    target=None,
    constraints=None,
    *,
    return_details: bool = False,
    **kwargs
) -> Union[float, Dict[str, Any]]:
    """
    Unified dispatcher for chemical rewards based on task type.

    Args:
        task_type: Type of task - one of:
            - 'single_count': Single property counting
            - 'multi_count': Multiple property counting
            - 'single_index': Single index identification
            - 'multi_index': Multiple index identification
            - 'constraint_generation': Constraint-based molecule generation
            - 'multi_count_dict': Alias for multi_count
            - 'multi_index_identification': Alias for multi_index
            - 'multi_constraint_generation': Alias for constraint_generation
        predicted: Model prediction (format depends on task type)
        target: Ground truth (for count/index tasks)
        constraints: Constraint list (for generation tasks)
        return_details: When ``True`` the dispatcher forwards the flag to the
            underlying helper so callers can request detailed diagnostics.
            Count and index helpers expose per-property match reports, while
            constraint helpers return per-constraint status information.
        **kwargs: Additional arguments passed through to underlying helpers.

    Returns:
        Union[float, Dict[str, Any]]: Reward value (float) or detailed dictionary
        for count tasks when ``return_details`` is ``True``.
    """
    # Extract comparator-specific kwargs so they are not forwarded downstream
    none_equals_zero = bool(kwargs.pop('none_equals_zero', False))
    tolerance = kwargs.pop('tolerance', None)

    # Normalize task type for routing
    task_type_norm = task_type.lower().replace('-', '_') if task_type else ''

    # Count / numeric style tasks
    if task_type_norm in ['single_count', 'count', 'multi_count', 'multi_count_dict', 'multiple_count']:
        if target is None:
            raise ValueError("Target required for count tasks")

        comparator_kwargs = {'none_equals_zero': none_equals_zero}
        if tolerance is not None:
            comparator_kwargs['tolerance'] = tolerance

        if return_details:
            reward_value = compare_numeric_values(
                target,
                predicted,
                **comparator_kwargs,
            )
            return _basic_detail(
                reward=float(reward_value),
                comparison='numeric',
                target=target,
                predicted=predicted,
            )

        reward_value = compare_numeric_values(
            target,
            predicted,
            **comparator_kwargs,
        )
        return float(reward_value)

    # Index identification tasks
    if task_type_norm in ['single_index', 'index', 'single_index_identification',
                           'multi_index', 'multi_index_identification', 'multiple_index']:
        if target is None:
            raise ValueError("Target required for index tasks")

        if return_details:
            reward_value = compare_index_sets(
                target,
                predicted,
                none_equals_zero=none_equals_zero,
            )
            return _basic_detail(
                reward=float(reward_value),
                comparison='indices',
                target=target,
                predicted=predicted,
            )

        reward_value = compare_index_sets(
            target,
            predicted,
            none_equals_zero=none_equals_zero,
        )
        return float(reward_value)

    # Constraint generation / SMILES tasks
    if task_type_norm in [
        'constraint_generation',
        'constraint',
        'generation',
        'multi_constraint_generation',
        'multi_constraint',
        'smiles_generation',
    ]:
        constraint_list = constraints if constraints is not None else target
        if constraint_list is None:
            raise ValueError("Constraints required for generation tasks")

        comparison = compare_smiles_with_constraints(
            predicted,
            constraint_list,
            target=target if isinstance(target, str) else None,
            return_details=return_details,
        )

        if return_details:
            return comparison

        return float(comparison)

    # Fallback: rely on auto-detect comparison if task type unknown
    if target is None:
        raise ValueError(f"Unknown task type '{task_type}' with no target provided.")

    reward_value = compare_answers(
        target,
        predicted,
        none_equals_zero=none_equals_zero,
        auto_detect=True,
    )

    if return_details:
        return _basic_detail(
            reward=float(reward_value),
            comparison=f"auto:{detect_answer_type(target)}",
            target=target,
            predicted=predicted,
        )

    return float(reward_value)


# Public API of this package. Every comparator helper imported from
# ``reward_comparators`` is re-exported, including ``compare_smiles``, which
# was previously imported but left out of this list.
__all__ = [
    # Main dispatcher
    'chemical_reward',

    # Count rewards
    'multi_count_dict_reward',
    'single_count_reward',

    # Index rewards
    'multi_index_identification_reward',
    'single_index_reward',

    # Constraint rewards
    'multi_constraint_generation_reward',
    'constraint_reward',

    # Utilities
    'valid_smiles',
    'is_reasonable_molecule',
    'evaluate_numeric_constraint',
    'parse_natural_language_property',

    # Comparator helpers
    'compare_answers',
    'compare_index_sets',
    'compare_numeric_values',
    'compare_smiles',
    'compare_smiles_with_constraints',
    'detect_answer_type',
    'extract_numeric_values',
    'extract_smiles_for_reward',
    'extract_values_from_dict',
]
