"""
Reward function for ARC-AGI-1 tasks.

This module provides reward functions that evaluate whether agents correctly
solve ARC-AGI-1 tasks. It uses reasoning_gym's scoring when available, otherwise
falls back to exact match. For grid-based tasks, it also computes structured
differences for richer feedback.
"""

import ast
import json
import re
from collections import Counter
from typing import Any, Dict, List, Optional

from rllm import Action
from rllm.rewards.reward_types import RewardOutput

Grid = List[List[int]]


def extract_solution(solution_str: str) -> str | None:
    """Extract the solution from the solution string."""
    # Remove everything before the first "Assistant:" if present
    if "Assistant:" in solution_str:
        solution_str = solution_str.split("Assistant:", 1)[1]
    elif "<|im_start|>assistant" in solution_str:
        solution_str = solution_str.split("<|im_start|>assistant", 1)[1]

    # Look for answer pattern in the entire string
    answer_pattern = r"<answer>(.*?)</answer>"
    match = re.finditer(answer_pattern, solution_str, re.IGNORECASE | re.DOTALL)
    matches = list(match)
    if matches:
        final_answer = matches[-1].group(1).strip()
        return final_answer

    # If no answer tags, return the stripped string
    return solution_str.strip() if solution_str.strip() else None


def _is_grid(obj: Any) -> bool:
    """Check if obj is a rectangular list-of-lists-of-ints grid."""
    if not isinstance(obj, (list, tuple)) or len(obj) == 0:
        return False
    row_len = None
    for row in obj:
        if not isinstance(row, (list, tuple)) or len(row) == 0:
            return False
        if row_len is None:
            row_len = len(row)
        elif len(row) != row_len:
            return False
        for x in row:
            if not isinstance(x, int):
                return False
    return True


def _to_grid(obj: Any) -> Optional[Grid]:
    """
    Coerce obj into a grid (list-of-lists-of-ints), or return None.

    Accepted inputs:
      - a value that already passes _is_grid (copied into plain lists)
      - a string holding a JSON document
      - a string holding a Python literal (parsed with ast.literal_eval)
      - either of the above wrapped in a ``` fenced code block
    Any other input, and any string that does not parse to a grid,
    yields None. Parser errors are swallowed on purpose (best effort).
    """
    if _is_grid(obj):
        # Normalise tuples / nested tuples into fresh plain lists of ints.
        return [[int(cell) for cell in row] for row in obj]

    if not isinstance(obj, str):
        return None

    text = obj.strip()

    if text.startswith("```"):
        # Drop the opening fence line (```lang or bare ```), if there is a newline.
        _opening, newline, remainder = text.partition("\n")
        if newline:
            text = remainder
        # Drop the closing fence and any whitespace left in front of it.
        if text.endswith("```"):
            text = text[: -len("```")].strip()

    # JSON is tried first, then Python literal syntax; the first parser
    # that produces a valid grid wins.
    for parse in (json.loads, ast.literal_eval):
        try:
            candidate = parse(text)
            if _is_grid(candidate):
                return [[int(cell) for cell in row] for row in candidate]
        except Exception:
            continue

    return None


def _grid_stats(grid: Grid) -> Dict[str, Any]:
    h = len(grid)
    w = len(grid[0]) if h > 0 else 0
    colors = [cell for row in grid for cell in row]
    color_counts = Counter(colors)
    return {
        "shape": (h, w),
        "color_counts": dict(color_counts),
        "num_cells": h * w,
    }


def _grid_diff(gt: Grid, pred: Grid) -> Dict[str, Any]:
    """
    Compute simple, programmatic differences between ground-truth and predicted grids.

    Args:
        gt: Ground-truth grid.
        pred: Predicted grid.

    Returns a dict like:
      {
          "shape_match": bool,
          "gt_shape": (H, W),
          "pred_shape": (H, W),
          "missing_colors": [int],      # in gt but not in pred, ascending
          "extra_colors": [int],        # in pred but not in gt, ascending
          "num_mismatched_cells": int,
          "sample_mismatches": [(row, col), ...],   # at most 20, row-major
      }

    Cell-level mismatches are only computed when the two shapes are equal;
    when they differ, "num_mismatched_cells" is 0 and "sample_mismatches"
    is empty, so consumers should check "shape_match" first.
    """
    stats_gt = _grid_stats(gt)
    stats_pred = _grid_stats(pred)

    shape_match = stats_gt["shape"] == stats_pred["shape"]

    gt_colors = set(stats_gt["color_counts"])
    pred_colors = set(stats_pred["color_counts"])
    # Sets iterate in arbitrary order; sort so the feedback is deterministic.
    missing_colors = sorted(gt_colors - pred_colors)
    extra_colors = sorted(pred_colors - gt_colors)

    mismatches: List[tuple[int, int]] = []
    if shape_match:
        mismatches = [
            (i, j)
            for i, (gt_row, pred_row) in enumerate(zip(gt, pred))
            for j, (gt_cell, pred_cell) in enumerate(zip(gt_row, pred_row))
            if gt_cell != pred_cell
        ]

    return {
        "shape_match": shape_match,
        "gt_shape": stats_gt["shape"],
        "pred_shape": stats_pred["shape"],
        "missing_colors": missing_colors,
        "extra_colors": extra_colors,
        "num_mismatched_cells": len(mismatches),
        # Cap the sample so the metadata stays small for large grids.
        "sample_mismatches": mismatches[:20],
    }


def _arc_exact_match_score(extracted: str, ground_truth: Any) -> float:
    """Return 1.0 if the prediction equals the ground truth, else 0.0.

    The whitespace-stripped strings are compared first. If they differ and
    both sides parse as grids via ``_to_grid``, the parsed grids are
    compared instead, so a correct grid is not rejected purely because of
    formatting (e.g. ``[[1,2]]`` versus ``str([[1, 2]])``).
    """
    if extracted.strip() == str(ground_truth).strip():
        return 1.0
    gt_grid = _to_grid(ground_truth)
    if gt_grid is not None and gt_grid == _to_grid(extracted):
        return 1.0
    return 0.0


def _arc_score_answer(extracted: str, ground_truth: Any, task_info: dict) -> float:
    """Score the prediction with reasoning_gym if possible, else by exact match.

    reasoning_gym is imported lazily; its ``score_answer`` is used only when
    the module exposes that attribute. An ImportError or AttributeError
    raised anywhere in that attempt falls back to exact match.
    """
    # NOTE(review): confirm that the installed reasoning_gym really exposes a
    # top-level ``score_answer``; if it does not, the fallback is always used.
    try:
        import reasoning_gym

        if hasattr(reasoning_gym, "score_answer"):
            entry = {
                "question": task_info.get("question", ""),
                "answer": ground_truth,
                "metadata": task_info.get("metadata", {}),
            }
            return reasoning_gym.score_answer(extracted, entry)
    except (ImportError, AttributeError):
        # Deliberate best effort: scoring must still work without reasoning_gym.
        pass
    return _arc_exact_match_score(extracted, ground_truth)


def arc_agi_1_reward_fn(task_info: dict, action: str | Action, eval: bool = False) -> RewardOutput:
    """
    A specialized reward function for arc-agi-1 tasks.

    Extracts the agent's answer, scores it against the ground truth
    (reasoning_gym's ``score_answer`` when the module provides it, otherwise
    an exact-match comparison that also accepts equal parsed grids), and
    labels the result. When both the ground truth and the prediction parse
    as grids, a structured diff is attached as metadata["diff"].

    Args:
        task_info: Dictionary read for "ground_truth" (or "answer" as a
            fallback), "question" and "metadata".
        action: The agent's solution string, or an Action whose ``.action``
            attribute holds it.
        eval: If True, a solution is correct only when score == 1.0.
            If False (training), score >= 0.9 counts as correct.

    Returns:
        RewardOutput with reward (the score as a float), is_correct, and
        metadata containing "validation", "score", a 200-character
        "extracted_solution" preview and optionally "diff" / "diff_error".
        If no solution can be extracted, no ground truth is present, or any
        exception is raised, the reward is 0.0 with validation
        "invalid_solution" and an "error" message.
    """
    try:
        if isinstance(action, Action):
            action = action.action

        # Extract solution from action (may contain <answer> tags)
        extracted = extract_solution(action)
        if extracted is None:
            return RewardOutput(
                reward=0.0,
                is_correct=False,
                metadata={
                    "validation": "invalid_solution",
                    "error": "No solution extracted from answer tags",
                },
            )

        # "answer" is only consulted when "ground_truth" is missing or falsy.
        ground_truth = task_info.get("ground_truth") or task_info.get("answer", "")
        if not ground_truth:
            return RewardOutput(
                reward=0.0,
                is_correct=False,
                metadata={
                    "validation": "invalid_solution",
                    "error": "No ground truth provided",
                },
            )

        score = _arc_score_answer(extracted, ground_truth, task_info)

        # Strict evaluation needs a perfect score; training tolerates >= 0.9.
        is_correct = (score == 1.0) if eval else (score >= 0.9)

        # Validation label: distinguish partial vs incorrect
        if is_correct:
            validation = "correct_solution"
        elif score > 0.0:
            validation = "partial_solution"
        else:
            validation = "incorrect_solution"

        metadata: Dict[str, Any] = {
            "validation": validation,
            "score": float(score),
            "extracted_solution": extracted[:200],  # Truncate for logging
        }

        # Attach a structured grid-level diff when both sides look like grids.
        try:
            gt_grid = _to_grid(ground_truth)
            pred_grid = _to_grid(extracted)
            if gt_grid is not None and pred_grid is not None:
                metadata["diff"] = _grid_diff(gt_grid, pred_grid)
        except Exception as diff_err:
            # The diff is optional feedback; record the failure instead of crashing.
            metadata["diff_error"] = str(diff_err)

        return RewardOutput(
            reward=float(score),
            is_correct=is_correct,
            metadata=metadata,
        )

    except Exception as e:
        # Top-level boundary: a reward function must always return a result.
        return RewardOutput(
            reward=0.0,
            is_correct=False,
            metadata={
                "validation": "invalid_solution",
                "error": str(e),
            },
        )


def arc_agi_1_reward_fn_eval(task_info: dict, action):
    """Score an action with arc_agi_1_reward_fn in strict (eval=True) mode."""
    strict_result = arc_agi_1_reward_fn(task_info, action, eval=True)
    return strict_result

