"""Utility functions implementing the new MolecularIQ reward comparisons.

The original implementation lived in ``chem_reasoning_baseline_eval.utils.new_rewards``
for post-hoc metric recomputation. This copy relocates the logic into the main
``lm_eval`` package so the harness can use it directly during evaluation.

The helpers focus on three comparison modes used across MolecularIQ tasks:

``numeric``
    Compare numbers or molecular formulae irrespective of dictionary keys.
``index``
    Compare collections of atom indices without relying on key names.
``smiles``
    Validate generated SMILES strings against constraint definitions using the
    existing constraint reward solver.

The module intentionally mirrors the public surface of the original utilities so
external tooling that imported them continues to function.
"""

from __future__ import annotations

import json
import re
from collections import Counter
from typing import Any, Dict, Iterable, List, Tuple, Union

from .constraint_reward import multi_constraint_generation_reward

# Best-effort side effect at import time: when ``rdkit`` can be imported, call
# ``RDLogger.DisableLog("rdApp.*")``. Any exception raised here (including the
# ImportError when RDKit is not installed) is deliberately swallowed so this
# guard never prevents the module from loading.
try:  # pragma: no cover - rdkit is an optional dependency
    from rdkit import RDLogger  # type: ignore

    RDLogger.DisableLog("rdApp.*")
except Exception:  # pragma: no cover - rdkit may be unavailable
    pass

# ---------------------------------------------------------------------------
# Index comparisons
# ---------------------------------------------------------------------------


def _coerce_to_int(value: Union[int, float, str, bool]) -> int:
    """Convert supported scalar types to ``int``, enforcing integer semantics."""
    if isinstance(value, bool):
        return int(value)
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        if value.is_integer():
            return int(value)
        raise ValueError(f"Non-integer float: {value}")
    if isinstance(value, str):
        stripped = value.strip()
        if not stripped or stripped.lower() in {"none", "null"}:
            raise ValueError("Empty or null string is not a valid index")
        if stripped.startswith("(") and stripped.endswith(")"):
            stripped = stripped[1:-1].strip()
        try:
            return int(stripped)
        except ValueError:
            number = float(stripped)
            if not number.is_integer():
                raise ValueError(f"Non-integer string value: {value}")
            return int(number)
    raise ValueError(f"Unsupported value type: {type(value)}")


def _convert_sequence(items: Iterable, none_equals_zero: bool) -> List[int]:
    """Convert a flat iterable of scalars into a ``List[int]`` preserving duplicates."""
    result: List[int] = []
    for item in items:
        if item is None:
            if none_equals_zero:
                result.append(0)
                continue
            raise ValueError("None value encountered in index data")
        if isinstance(item, (list, tuple)):
            result.extend(_convert_sequence(item, none_equals_zero))
            continue
        result.append(_coerce_to_int(item))
    return result


def _parse_index_json_like(text: str) -> Union[dict, list, int, float, str, None]:
    """Parse JSON-like text, normalising simple bracket variants."""
    try:
        return json.loads(text)
    except json.JSONDecodeError as exc:
        adjusted = text.replace("'", '"').replace('(', '[').replace(')', ']')
        if adjusted != text:
            try:
                return json.loads(adjusted)
            except json.JSONDecodeError as exc_inner:
                raise ValueError(f"Unable to parse string as JSON: {text}") from exc_inner
        raise ValueError(f"Unable to parse string as JSON: {text}") from exc


def extract_values_from_dict(
    data: Union[str, dict, list, int, float, None],
    none_equals_zero: bool = False,
) -> List[List[int]]:
    """Extract groups of integer indices from a JSON-like structure.

    Dictionary keys are ignored; only the values contribute. The result is a
    list of groups, each group being a list of ints:

    * ``None``, blank / ``"none"`` / ``"null"`` strings and empty containers
      yield ``[[0]]`` when ``none_equals_zero`` is true, otherwise ``[[]]``.
    * Strings starting with ``{`` or ``[`` are decoded with
      :func:`_parse_index_json_like` and processed recursively.
    * Strings starting with ``(`` (e.g. ``"(1, 2)"``) are first tried through
      the same decoder, which maps parentheses to brackets. If that fails the
      string falls through to the comma / scalar handling below, so inputs
      such as ``"(5)"`` or ``"(1),(2)"`` keep working.
    * Other strings are treated as a comma-separated list or a single scalar.
    * A dict contributes the groups of each value; a list/tuple made up solely
      of lists/tuples contributes the groups of each item; any other
      list/tuple is flattened into a single group.

    Raises:
        ValueError: If the input looks like a SMILES answer, cannot be parsed,
            contains non-integral values, or has an unsupported type.
    """
    # Built per call, so the returned list is never shared between callers.
    placeholder: List[List[int]] = [[0]] if none_equals_zero else [[]]

    if data is None:
        return placeholder

    if isinstance(data, str):
        stripped = data.strip()
        lowered = stripped.lower()
        if not stripped or lowered in {"none", "null"}:
            return placeholder
        if "smiles" in lowered and "[" not in stripped:
            raise ValueError("SMILES-like string provided for index comparison")
        if stripped.startswith(("{", "[")):
            parsed = _parse_index_json_like(stripped)
            return extract_values_from_dict(parsed, none_equals_zero)
        if stripped.startswith("("):
            # Tuple-style answers ("(1, 2)", "((0, 1), (2, 3))") used to reach
            # the comma split and fail on "(1". Try the bracket-normalising
            # decoder first; on any failure keep the legacy behaviour.
            try:
                return extract_values_from_dict(
                    _parse_index_json_like(stripped), none_equals_zero
                )
            except ValueError:
                pass
        if "," in stripped:
            parts = [part.strip() for part in stripped.split(",") if part.strip()]
            if not parts:
                return placeholder
            return [[_coerce_to_int(part) for part in parts]]
        try:
            return [[_coerce_to_int(stripped)]]
        except ValueError as exc:
            raise ValueError(f"Unsupported index string: {data}") from exc

    if isinstance(data, dict):
        if not data:
            return placeholder
        groups: List[List[int]] = []
        for value in data.values():
            groups.extend(extract_values_from_dict(value, none_equals_zero))
        return groups

    if isinstance(data, (list, tuple)):
        if not data:
            return placeholder
        if all(isinstance(item, (list, tuple)) for item in data):
            nested: List[List[int]] = []
            for item in data:
                nested.extend(extract_values_from_dict(item, none_equals_zero))
            return nested
        # Mixed or scalar content: flatten everything into one group.
        return [_convert_sequence(data, none_equals_zero)]

    if isinstance(data, (int, float, bool)):
        return [[_coerce_to_int(data)]]

    raise ValueError(f"Unsupported data type for index extraction: {type(data)}")


def _normalize_groups(groups: List[List[int]]) -> Counter:
    normalized = Counter()
    for group in groups:
        normalized[tuple(sorted(group))] += 1
    return normalized


def compare_index_sets(
    target: Union[str, dict, list, int, float, None],
    predicted: Union[str, dict, list, int, float, None],
    none_equals_zero: bool = False,
) -> int:
    """Return 1 when both inputs hold the same index groups, else 0.

    Dictionary keys, the order of groups and the order of indices inside a
    group are all ignored; group multiplicity is respected. A ``ValueError``
    while extracting ``predicted`` counts as a mismatch (0), whereas a
    ``ValueError`` while extracting ``target`` propagates to the caller.
    """
    expected = _normalize_groups(extract_values_from_dict(target, none_equals_zero))
    try:
        observed_groups = extract_values_from_dict(predicted, none_equals_zero)
    except ValueError:
        # An unparseable prediction is simply wrong, not an error.
        return 0
    return 1 if expected == _normalize_groups(observed_groups) else 0


def batch_compare(
    targets: List[Union[str, dict, list, int, float, None]],
    predictions: List[Union[str, dict, list, int, float, None]],
    none_equals_zero: bool = False,
) -> List[int]:
    """Apply :func:`compare_index_sets` pairwise and collect the 0/1 scores.

    Pairs are formed with ``zip``, so surplus items in the longer input are
    ignored.
    """
    scores: List[int] = []
    for expected, observed in zip(targets, predictions):
        scores.append(compare_index_sets(expected, observed, none_equals_zero))
    return scores


# ---------------------------------------------------------------------------
# Numeric / formula comparisons
# ---------------------------------------------------------------------------

# Canonical molecular formula: a tuple of (element symbol, count) pairs plus
# a charge descriptor string ('' when no charge was found).
FormulaRepresentation = Tuple[Tuple[Tuple[str, int], ...], str]
# Tagged value produced by numeric extraction: ('number', float) or
# ('formula', FormulaRepresentation).
NumericToken = Tuple[str, Union[float, FormulaRepresentation]]

# A formula body (charge already split off) may hold only ASCII letters/digits.
_FORMULA_BODY = re.compile(r"^[A-Za-z0-9]+$")
# Element symbols accepted when parsing formulae: elements 1-118 (H through Og).
_ELEMENT_SYMBOLS = {
    'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al', 'Si',
    'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni',
    'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo',
    'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba',
    'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb',
    'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Po',
    'At', 'Rn', 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf',
    'Es', 'Fm', 'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 'Cn',
    'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og'
}
_SUBSCRIPT_TRANSLATION = str.maketrans('₀₁₂₃₄₅₆₇₈₉', '0123456789')


def _normalize_formula_text(text: str) -> str:
    normalized = text.translate(_SUBSCRIPT_TRANSLATION)
    normalized = normalized.replace(' ', '').replace('-', '').replace('·', '').replace('^', '')
    return normalized


def _split_formula_and_charge(text: str) -> Tuple[str, str]:
    """Split a normalised formula into element body and optional charge descriptor."""
    if not text:
        return text, ''

    idx = len(text)
    charge_chars: List[str] = []

    while idx > 0 and text[idx - 1] in '+-':
        charge_chars.append(text[idx - 1])
        idx -= 1

    if charge_chars:
        charge = ''.join(reversed(charge_chars))
        digit_idx = idx
        while digit_idx > 0 and text[digit_idx - 1].isdigit():
            charge = text[digit_idx - 1] + charge
            digit_idx -= 1
        return text[:digit_idx], _canonicalize_charge(charge)

    digit_idx = len(text)
    while digit_idx > 0 and text[digit_idx - 1].isdigit():
        digit_idx -= 1
    if digit_idx > 0 and text[digit_idx - 1] in '+-':
        charge = text[digit_idx - 1] + text[digit_idx:]
        return text[:digit_idx - 1], _canonicalize_charge(charge)

    return text, ''


def _canonicalize_charge(charge: str) -> str:
    if not charge:
        return ''
    if charge[-1] in '+-' and charge[:-1].isdigit():
        return charge
    if charge[0] in '+-' and charge[1:].isdigit():
        return charge[1:] + charge[0]
    return charge


def _parse_formula_to_token(formula: str) -> NumericToken:
    """Parse a molecular formula string into a ``('formula', ...)`` token.

    The text is normalised, any trailing charge is split off, and the
    remaining body is read as a run of ``<symbol><optional count>`` pairs.
    Symbols are case-normalised (first letter upper, second lower) and must
    be known element symbols. Counts of a repeated element are summed.

    Returns:
        ``('formula', (pairs, charge))`` where ``pairs`` is a tuple of
        ``(symbol, count)`` sorted by symbol.

    Raises:
        ValueError: If the text has no letters, contains characters other
            than ASCII letters/digits in its body, cannot be fully consumed
            as symbol/count pairs, or names an unknown element.
    """
    cleaned = _normalize_formula_text(formula)
    if not (cleaned and any(ch.isalpha() for ch in cleaned)):
        raise ValueError(f"Unsupported numeric string: {formula}")

    body, charge = _split_formula_and_charge(cleaned)
    unparseable = f"Unable to parse molecular formula: {formula}"
    if not body or _FORMULA_BODY.match(body) is None:
        raise ValueError(unparseable)

    pairs = re.findall(r'([A-Za-z][a-z]?)(\d*)', body)
    if not pairs:
        raise ValueError(unparseable)
    # findall silently skips text it cannot match (e.g. a leading digit), so
    # make sure the matches account for every character of the body.
    matched_length = sum(len(symbol) + len(digits) for symbol, digits in pairs)
    if matched_length != len(body):
        raise ValueError(unparseable)

    tally: Counter = Counter()
    for raw_symbol, digits in pairs:
        symbol = raw_symbol[0].upper() + raw_symbol[1:].lower()
        if symbol not in _ELEMENT_SYMBOLS:
            raise ValueError(f"Unknown element symbol in formula: {symbol}")
        tally[symbol] += int(digits) if digits else 1

    return ('formula', (tuple(sorted(tally.items())), charge))


def _parse_numeric_json_like(text: str) -> Any:
    try:
        return json.loads(text, strict=False)
    except json.JSONDecodeError as exc:
        adjusted = text.replace("'", '"')
        if adjusted != text:
            try:
                return json.loads(adjusted, strict=False)
            except json.JSONDecodeError:
                pass
        tokenized = re.sub(r'(:\s*)([A-Za-z][A-Za-z0-9+\-]*)', r'\1"\2"', text)
        if tokenized != text:
            try:
                return json.loads(tokenized, strict=False)
            except json.JSONDecodeError as exc_inner:
                raise ValueError(f"Unable to parse numeric JSON: {text}") from exc_inner
        raise ValueError(f"Unable to parse numeric JSON: {text}") from exc


def _extract_numeric_tokens(obj: Any, none_equals_zero: bool) -> List[NumericToken]:
    tokens: List[NumericToken] = []

    if obj is None:
        if none_equals_zero:
            tokens.append(('number', 0.0))
        return tokens

    if isinstance(obj, bool):
        tokens.append(('number', float(int(obj))))
        return tokens

    if isinstance(obj, (int, float)):
        tokens.append(('number', float(obj)))
        return tokens

    if isinstance(obj, str):
        stripped = obj.strip()
        if not stripped:
            if none_equals_zero:
                tokens.append(('number', 0.0))
                return tokens
            raise ValueError("Empty string is not a valid numeric value")
        lowered = stripped.lower()
        if lowered in {"none", "null"}:
            if none_equals_zero:
                tokens.append(('number', 0.0))
                return tokens
            return tokens
        if stripped.startswith("(") and stripped.endswith(")") and len(stripped) > 2:
            stripped = stripped[1:-1].strip()
        if stripped.startswith(("{", "[")):
            parsed = _parse_numeric_json_like(stripped)
            tokens.extend(_extract_numeric_tokens(parsed, none_equals_zero))
            return tokens
        try:
            tokens.append(('number', float(stripped)))
            return tokens
        except ValueError:
            tokens.append(_parse_formula_to_token(stripped))
            return tokens

    if isinstance(obj, dict):
        if not obj and none_equals_zero:
            tokens.append(('number', 0.0))
            return tokens
        for value in obj.values():
            tokens.extend(_extract_numeric_tokens(value, none_equals_zero))
        return tokens

    if isinstance(obj, (list, tuple)):
        if not obj and none_equals_zero:
            tokens.append(('number', 0.0))
            return tokens
        for item in obj:
            tokens.extend(_extract_numeric_tokens(item, none_equals_zero))
        return tokens

    raise ValueError(f"Unsupported data type for numeric extraction: {type(obj)}")


def extract_numeric_values(
    data: Union[str, dict, List, Tuple, int, float, None],
    none_equals_zero: bool = False,
) -> List[NumericToken]:
    """Return the number / formula tokens found in a JSON-like structure.

    Public entry point that delegates to :func:`_extract_numeric_tokens`;
    each token is ``('number', float)`` or ``('formula', ...)``.
    """
    tokens = _extract_numeric_tokens(data, none_equals_zero)
    return tokens


def _normalize_numeric_tokens(tokens: List[NumericToken]) -> Counter:
    return Counter(tokens)


def compare_numeric_values(
    target: Union[str, dict, List, Tuple, int, float, None],
    predicted: Union[str, dict, List, Tuple, int, float, None],
    none_equals_zero: bool = False,
    tolerance: float = 1e-9,
) -> int:
    """Return 1 when both inputs hold the same numbers and formulae, else 0.

    Dictionary keys and ordering are ignored. Numbers are sorted and compared
    pairwise with an absolute ``tolerance``; formula tokens must match exactly
    as multisets. A ``ValueError`` while extracting ``predicted`` counts as a
    mismatch (0), whereas a ``ValueError`` while extracting ``target``
    propagates to the caller.

    NaN handling: NaN values are only ever paired with NaN values. Without
    this, ``abs(t - nan) > tolerance`` is always False, so a prediction of
    ``"nan"`` would be scored as matching any numeric target.
    """
    target_tokens = extract_numeric_values(target, none_equals_zero)
    try:
        predicted_tokens = extract_numeric_values(predicted, none_equals_zero)
    except ValueError:
        # An unparseable prediction is simply wrong, not an error.
        return 0

    target_numbers = [value for kind, value in target_tokens if kind == 'number']
    predicted_numbers = [value for kind, value in predicted_tokens if kind == 'number']
    if len(target_numbers) != len(predicted_numbers):
        return 0

    # ``value == value`` is False only for NaN. Dropping NaNs before sorting
    # keeps the ordering well-defined; requiring equal counts of the remaining
    # values (given equal totals) means both sides have the same number of NaNs.
    target_ordered = sorted(value for value in target_numbers if value == value)
    predicted_ordered = sorted(value for value in predicted_numbers if value == value)
    if len(target_ordered) != len(predicted_ordered):
        return 0

    for t_val, p_val in zip(target_ordered, predicted_ordered):
        if t_val == p_val:
            # Exact match; also covers equal infinities (inf - inf is NaN).
            continue
        if abs(t_val - p_val) > tolerance:
            return 0

    target_formulas = Counter(token for token in target_tokens if token[0] == 'formula')
    predicted_formulas = Counter(token for token in predicted_tokens if token[0] == 'formula')
    return int(target_formulas == predicted_formulas)


def batch_compare_numeric(
    targets: List[Union[str, dict, List, Tuple, int, float, None]],
    predictions: List[Union[str, dict, List, Tuple, int, float, None]],
    none_equals_zero: bool = False,
) -> List[int]:
    """Apply :func:`compare_numeric_values` pairwise and collect the scores.

    Pairs are formed with ``zip``, so surplus items in the longer input are
    ignored. The default tolerance of :func:`compare_numeric_values` is used.
    """
    scores: List[int] = []
    for expected, observed in zip(targets, predictions):
        scores.append(compare_numeric_values(expected, observed, none_equals_zero))
    return scores


# ---------------------------------------------------------------------------
# Unified comparison helpers
# ---------------------------------------------------------------------------


def detect_answer_type(data: Union[str, dict, None]) -> str:
    """Classify an answer as ``'none'``, ``'indices'``, ``'numeric'``,
    ``'mixed'`` or ``'unknown'``.

    * ``None`` -> ``'none'``.
    * Strings are decoded as JSON and the decoded value is classified; if
      decoding fails, text containing both ``[`` and ``]`` is ``'indices'``
      and anything else is ``'numeric'``.
    * Dicts are judged by their values: lists only -> ``'indices'``;
      numbers/``None`` only -> ``'numeric'``; both -> ``'mixed'``; neither ->
      ``'unknown'``.
    * Lists -> ``'indices'``; ints/floats -> ``'numeric'``.
    * Everything else -> ``'unknown'``.
    """
    if data is None:
        return 'none'

    if isinstance(data, str):
        raw = data
        try:
            data = json.loads(raw)
        except json.JSONDecodeError:
            looks_bracketed = '[' in raw and ']' in raw
            return 'indices' if looks_bracketed else 'numeric'

    if isinstance(data, dict):
        values = list(data.values())
        has_lists = any(isinstance(value, list) for value in values)
        has_numbers = any(
            value is None or isinstance(value, (int, float)) for value in values
        )
        verdicts = {
            (True, True): 'mixed',
            (True, False): 'indices',
            (False, True): 'numeric',
        }
        return verdicts.get((has_lists, has_numbers), 'unknown')

    if isinstance(data, list):
        return 'indices'
    if isinstance(data, (int, float)):
        return 'numeric'
    return 'unknown'


def compare_answers(
    target: Union[str, dict, None],
    predicted: Union[str, dict, None],
    none_equals_zero: bool = False,
    auto_detect: bool = True,
    answer_type: str | None = None,
) -> int:
    """Compare two answers, choosing numeric or index comparison.

    When ``answer_type`` is not given and ``auto_detect`` is true, the type is
    inferred from both inputs with :func:`detect_answer_type`: ``'indices'``
    on either side wins, then ``'numeric'``. Only a resulting type of
    ``'numeric'`` selects :func:`compare_numeric_values`; every other value
    (including ``None``) selects :func:`compare_index_sets`.
    """
    if answer_type is None and auto_detect:
        target_type = detect_answer_type(target)
        predicted_type = detect_answer_type(predicted)
        detected = (target_type, predicted_type)

        if 'indices' in detected:
            answer_type = 'indices'
        elif 'numeric' in detected:
            answer_type = 'numeric'
        elif 'none' in detected:
            # Use whichever side actually carried a value.
            answer_type = predicted_type if target_type == 'none' else target_type
        else:
            answer_type = 'indices'

    comparator = compare_numeric_values if answer_type == 'numeric' else compare_index_sets
    return comparator(target, predicted, none_equals_zero)


def batch_compare_answers(
    targets: List[Union[str, dict, None]],
    predictions: List[Union[str, dict, None]],
    none_equals_zero: bool = False,
    auto_detect: bool = True,
) -> List[int]:
    """Apply :func:`compare_answers` pairwise and collect the 0/1 scores.

    Pairs are formed with ``zip``, so surplus items in the longer input are
    ignored.
    """
    scores: List[int] = []
    for expected, observed in zip(targets, predictions):
        scores.append(compare_answers(expected, observed, none_equals_zero, auto_detect))
    return scores


# ---------------------------------------------------------------------------
# SMILES comparisons via constraint validation
# ---------------------------------------------------------------------------


def extract_smiles_for_reward(predicted: Union[str, dict, None]) -> str:
    """Pull a SMILES string out of a loosely formatted answer.

    * ``None`` -> ``""``.
    * dict -> ``str()`` of its first value (``""`` if that value is ``None``
      or the dict is empty); the key name is not inspected.
    * str -> stripped; a ``{...}`` string that decodes to a non-empty JSON
      object is treated like a dict; a leading ``smiles:`` label (any case)
      is removed; otherwise the stripped text is returned unchanged.
    * anything else -> ``str(predicted)``.
    """
    if predicted is None:
        return ""

    if isinstance(predicted, dict):
        for first_value in predicted.values():
            return "" if first_value is None else str(first_value)
        return ""

    if not isinstance(predicted, str):
        return str(predicted)

    text = predicted.strip()

    if text.startswith('{') and text.endswith('}'):
        try:
            decoded = json.loads(text)
        except (json.JSONDecodeError, ValueError):
            decoded = None
        if isinstance(decoded, dict) and decoded:
            return extract_smiles_for_reward(decoded)

    if text.lower().startswith('smiles:'):
        _, _, remainder = text.partition(':')
        return remainder.strip()

    return text


def compare_smiles_with_constraints(
    predicted: Union[str, dict, None],
    constraints: Union[str, List[Dict[str, Any]], Dict[str, Any], None] = None,
    target: str | None = None,
    return_details: bool = False,
) -> Union[float, Dict[str, Any]]:
    """Score a predicted SMILES against constraints via the constraint solver.

    ``constraints`` may be a list of constraint dicts, a single dict (wrapped
    into a one-item list) or a JSON string encoding either. The SMILES is
    extracted with :func:`extract_smiles_for_reward` and handed to
    ``multi_constraint_generation_reward`` together with the constraints.

    When ``return_details`` is true and the solver returns a dict, the keys
    ``extracted_smiles``, ``original_predicted`` and (if ``target`` is given)
    ``target`` are added to that dict before it is returned.

    Raises:
        ValueError: If ``constraints`` is ``None``, blank, empty, or a string
            that is not valid JSON.
        TypeError: If the constraints are neither a list nor a dict.
    """
    if constraints is None:
        raise ValueError("constraints must be provided for SMILES comparison")

    extracted_smiles = extract_smiles_for_reward(predicted)

    parsed: Any = constraints
    if isinstance(parsed, str):
        parsed = parsed.strip()
        if parsed == "":
            raise ValueError("constraints must not be empty")
        try:
            parsed = json.loads(parsed)
        except json.JSONDecodeError as exc:
            raise ValueError("constraints string is not valid JSON") from exc

    if isinstance(parsed, dict):
        constraint_items: List[Dict[str, Any]] = [parsed]
    elif isinstance(parsed, list):
        constraint_items = parsed
    else:
        raise TypeError("constraints must be a list, dict, or JSON string")

    if not constraint_items:
        raise ValueError("constraints must not be empty")

    outcome = multi_constraint_generation_reward(
        extracted_smiles,
        constraint_items,
        return_details=return_details,
    )

    if return_details and isinstance(outcome, dict):
        # Attach provenance so callers can see what was actually scored.
        outcome["extracted_smiles"] = extracted_smiles
        outcome["original_predicted"] = predicted
        if target is not None:
            outcome["target"] = target

    return outcome


def batch_compare_smiles(
    predictions: List[Union[str, dict, None]],
    constraints: Union[str, List[Dict[str, Any]], Dict[str, Any], None] = None,
    targets: List[str] | None = None,
) -> List[float]:
    """Score every prediction against the same ``constraints``.

    ``targets`` is optional and may be shorter than ``predictions``; missing
    entries are passed on as ``None``. If the per-item result is a dict, its
    ``"reward"`` entry (default ``0.0``) is used as the score.
    """
    scores: List[float] = []
    for position, predicted in enumerate(predictions):
        has_target = bool(targets) and position < len(targets)
        target = targets[position] if has_target else None
        outcome = compare_smiles_with_constraints(predicted, constraints, target)
        raw_score = outcome.get("reward", 0.0) if isinstance(outcome, dict) else outcome
        scores.append(float(raw_score))
    return scores


def compare_smiles(predicted: Union[str, dict, None]) -> float:
    """Legacy entry point kept for API compatibility; always raises.

    Raises:
        ValueError: Unconditionally, pointing callers to
            :func:`compare_smiles_with_constraints`.
    """
    message = (
        "compare_smiles requires explicit constraints. "
        "Use compare_smiles_with_constraints "
        "with a non-empty constraints list."
    )
    raise ValueError(message)


# Names exported by ``from <module> import *``; underscore-prefixed helpers
# above are intentionally left out.
__all__ = [
    # Index helpers
    'extract_values_from_dict',
    'compare_index_sets',
    'batch_compare',
    # Numeric helpers
    'extract_numeric_values',
    'compare_numeric_values',
    'batch_compare_numeric',
    # Unified helpers
    'detect_answer_type',
    'compare_answers',
    'batch_compare_answers',
    # SMILES helpers
    'extract_smiles_for_reward',
    'compare_smiles_with_constraints',
    'batch_compare_smiles',
    'compare_smiles',
]
