"""
Intermediate Grid Reward Functions

Provides step-wise reward functions that score the intermediate grids a
model emits in its reasoning (inside <intermediate_grid> tags) against a
reference solution path, rather than only checking the final answer.

Ported from an internal reward-functions module (pre-existing utility code)
"""

import re
from typing import List, Optional, Tuple


def extract_intermediate_grids(response_text: str) -> List[str]:
    """
    Collect every grid the model wrapped in <intermediate_grid> tags.

    The match is non-greedy and spans newlines (DOTALL), so each opening
    tag pairs with the nearest closing tag after it. Leading and trailing
    whitespace is stripped from each captured grid.

    Args:
        response_text: Full model response

    Returns:
        Grid strings in the order they appear in the response (an empty
        list when there are none)
    """
    tag_re = re.compile(r'<intermediate_grid>(.*?)</intermediate_grid>', re.DOTALL)
    return [match.group(1).strip() for match in tag_re.finditer(response_text)]


def extract_solution_path_from_extra_info(extra_info: Optional[dict]) -> Optional[List[str]]:
    """
    Extract ground truth solution path from extra_info.

    Keys are tried in this order: 'solution_path', 'intermediate_grids',
    'path', 'steps'. The first key whose value is a list or tuple (returned
    as a new list, elements as stored) or a '|||'-delimited string (split,
    each piece stripped, empty pieces dropped) wins. Keys holding any other
    type are skipped.

    Args:
        extra_info: Dict of sample metadata, or None when the sample has
            no metadata.

    Returns:
        List of intermediate grids, or None if extra_info is None or no
        usable key is present.
    """
    # Reward callers may pass extra_info=None for samples without metadata;
    # treat that as "no path" instead of raising TypeError on `key in None`.
    if extra_info is None:
        return None

    for key in ('solution_path', 'intermediate_grids', 'path', 'steps'):
        if key not in extra_info:
            continue
        path = extra_info[key]
        if isinstance(path, (list, tuple)):
            # Return a copy so callers cannot mutate the stored metadata.
            return list(path)
        if isinstance(path, str):
            # Serialized form: grids joined by the '|||' delimiter.
            return [s.strip() for s in path.split('|||') if s.strip()]

    return None


def intermediate_grid_reward_simple(
    response_text: str,
    ground_truth: str,
    extra_info: dict,
    final_answer: str = None
) -> float:
    """
    Simple intermediate grid reward: exact-match rate of intermediate grids.

    Predicted grids (from <intermediate_grid> tags) are compared
    position-by-position with the ground truth solution path, after
    stripping surrounding whitespace. Only the first
    min(#predicted, #ground-truth) positions are compared; extra grids on
    either side are ignored.

    Args:
        response_text: Full model response (with reasoning)
        ground_truth: Final solution (unused here)
        extra_info: Dict containing 'solution_path'
        final_answer: Extracted final answer (optional, unused here)

    Returns:
        Matched positions / compared positions (0.0 to 1.0); 0.0 when
        either side has no grids
    """
    predicted = extract_intermediate_grids(response_text)
    reference = extract_solution_path_from_extra_info(extra_info)

    # Nothing to score unless both sides contributed at least one grid.
    if not (reference and predicted):
        return 0.0

    # zip stops at the shorter sequence, so only overlapping positions count.
    compared = list(zip(predicted, reference))
    hits = sum(1 for pred, ref in compared if pred.strip() == ref.strip())
    return hits / len(compared)


def intermediate_grid_reward_position(
    response_text: str,
    ground_truth: str,
    extra_info: dict,
    final_answer: str = None
) -> float:
    """
    Position-weighted intermediate grid reward.

    Grids later in the solution path count for more, since reaching them
    reflects more progress toward the solution.

    Weight formula: w_i = (i + 1) / N
    where i is the 0-based grid index and N is the length of the ground
    truth solution path. Only the first min(#predicted, N) positions are
    compared, and the result is normalized by the summed weight of those
    positions (so the 1/N factor cancels out).

    Args:
        response_text: Full model response (with reasoning)
        ground_truth: Final solution (unused here)
        extra_info: Dict containing 'solution_path'
        final_answer: Extracted final answer (optional, unused here)

    Returns:
        Reward score (0.0 to 1.0); 0.0 when either side has no grids
    """
    predicted = extract_intermediate_grids(response_text)
    reference = extract_solution_path_from_extra_info(extra_info)

    if not reference or not predicted:
        return 0.0

    path_len = len(reference)
    earned = 0.0
    possible = 0.0
    # Accumulate sequentially (not via sum()) so float results match
    # a plain running total exactly.
    for idx, (pred, ref) in enumerate(zip(predicted, reference)):
        w = (idx + 1) / path_len  # linearly increasing weight
        possible += w
        if pred.strip() == ref.strip():
            earned += w

    return earned / possible if possible != 0 else 0.0


def intermediate_grid_reward_hybrid(
    response_text: str,
    ground_truth: str,
    extra_info: dict,
    final_answer: str = None,
    simple_weight: float = 0.5,
    position_weight: float = 0.5
) -> float:
    """
    Hybrid intermediate grid reward.

    Weighted average of the simple and position-weighted rewards. The two
    weights are normalized by their sum; if they sum to zero the result
    is 0.0.

    Args:
        response_text: Full model response (with reasoning)
        ground_truth: Final solution
        extra_info: Dict containing 'solution_path'
        final_answer: Extracted final answer (optional)
        simple_weight: Weight for simple reward (default: 0.5)
        position_weight: Weight for position reward (default: 0.5)

    Returns:
        Reward score (0.0 to 1.0)
    """
    # Both component scores are computed up front, before the weights are checked.
    s_score = intermediate_grid_reward_simple(response_text, ground_truth, extra_info, final_answer)
    p_score = intermediate_grid_reward_position(response_text, ground_truth, extra_info, final_answer)

    denominator = simple_weight + position_weight
    blended = s_score * simple_weight + p_score * position_weight
    return 0.0 if denominator == 0 else blended / denominator


def intermediate_grid_reward(
    response_text: str,
    ground_truth: str,
    extra_info: dict,
    final_answer: str = None,
    method: str = "simple"
) -> float:
    """
    Dispatch to one of the intermediate grid reward variants.

    The hybrid variant is called with its default (equal) weights.

    Args:
        response_text: Full model response (with reasoning)
        ground_truth: Final solution
        extra_info: Dict containing 'solution_path'
        final_answer: Extracted final answer (optional)
        method: Reward method ("simple", "position", "hybrid")

    Returns:
        Reward score (0.0 to 1.0)

    Raises:
        ValueError: If method is not one of the supported names.
    """
    if method not in ("simple", "position", "hybrid"):
        raise ValueError(f"Unknown intermediate reward method: {method}")

    call_args = (response_text, ground_truth, extra_info, final_answer)
    if method == "simple":
        return intermediate_grid_reward_simple(*call_args)
    if method == "position":
        return intermediate_grid_reward_position(*call_args)
    return intermediate_grid_reward_hybrid(*call_args)


# Test code
if __name__ == "__main__":
    # Sample response: two <intermediate_grid> blocks in the reasoning,
    # followed by the final grid in <answer>.
    test_response = """<reasoning>
Let me solve this step by step.

<intermediate_grid>
7,4,4,4,7
2,4,4,4,2
4,5,4,4,7
4,4,4,4,4
8,4,8,4,6
</intermediate_grid>

Now I'll add more bridges.

<intermediate_grid>
7,3,3,3,7
2,4,4,4,2
2,5,4,4,7
2,4,4,4,4
8,4,8,4,6
</intermediate_grid>

Final solution:
</reasoning>
<answer>
7,3,3,3,7
2,4,4,4,2
2,5,3,3,7
2,4,4,4,4
8,D,8,D,6
</answer>"""

    # Reference path: both intermediate states plus the final grid.
    step_one = """7,4,4,4,7
2,4,4,4,2
4,5,4,4,7
4,4,4,4,4
8,4,8,4,6"""
    step_two = """7,3,3,3,7
2,4,4,4,2
2,5,4,4,7
2,4,4,4,4
8,4,8,4,6"""
    final_grid = """7,3,3,3,7
2,4,4,4,2
2,5,3,3,7
2,4,4,4,4
8,D,8,D,6"""

    extra = {"solution_path": [step_one, step_two, final_grid]}

    print("Testing intermediate grid extraction:")
    grids = extract_intermediate_grids(test_response)
    print(f"  Extracted {len(grids)} grids")

    # Run each reward variant against the same response and path.
    for heading, reward_fn in (
        ("\nTesting simple intermediate reward:", intermediate_grid_reward_simple),
        ("\nTesting position-weighted intermediate reward:", intermediate_grid_reward_position),
        ("\nTesting hybrid intermediate reward:", intermediate_grid_reward_hybrid),
    ):
        print(heading)
        score = reward_fn(test_response, "", extra)
        print(f"  Score: {score:.4f}")
