#!/usr/bin/env python3
"""
Maze preprocessing for verl multi-turn "Ver@K retry with verifier feedback".

This script generates solved lattice mazes using the `maze-dataset` library and writes:
  <local_save_dir>/train.parquet
  <local_save_dir>/test.parquet

Output schema (per row):
  - data_source: str                         (e.g., "maze_dataset/dfs_5x5")
  - prompt: List[{"role": str, "content": str}]
  - ability: str                             ("maze")
  - reward_model: {"style": "rule", "ground_truth": str}
  - extra_info:
      - split: str                           ("train" | "test")
      - index: int
      - interaction_kwargs:
          - name: str                        (your interaction name)
          - query: str                       (prompt[-1]["content"])
          - ground_truth: str                (move string on ASCII grid, e.g., "RRRRDD")
          - max_attempts: int                (K)
          - maze:
              - grid_n: int
              - connection_list: list        (shape [2, grid_n, grid_n], ints)
              - start_pos: [r, c]
              - end_pos: [r, c]

Why interaction_kwargs is inside extra_info:
  verl.utils.dataset.rl_dataset.RLHFDataset.__getitem__ pulls:
      interaction_kwargs = row_dict["extra_info"]["interaction_kwargs"]
  and then sets row_dict["interaction_kwargs"] for AgentLoop/rollout. (So store it in extra_info!)

Usage:
  python examples/data_preprocess/maze_5x5_ver_k_retry.py \
    --grid_n 5 --n_train 100 --n_test 10 --k_max_attempts 4
"""

from __future__ import annotations

import argparse
import os
from pathlib import Path
from typing import Any, Dict, Optional

import datasets
import numpy as np

from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators


# Value written to the "ability" column of every generated row.
DEFAULT_ABILITY = "maze"


def _build_prompt_messages(
    maze_ascii: str,
    system_prompt: Optional[str] = None,
) -> list[dict]:
    """Build the chat-format prompt (list of role/content dicts) for one maze.

    Args:
        maze_ascii: Rendered maze text; trailing whitespace is dropped.
        system_prompt: When truthy, emitted as a leading "system" message.

    Returns:
        ``[user_message]``, or ``[system_message, user_message]`` when a
        system prompt is supplied.

    Raises:
        ValueError: If ``maze_ascii`` is falsy or only whitespace at its tail
            (i.e. nothing is left after right-stripping).
    """
    rendered = maze_ascii.rstrip() if maze_ascii else ""
    if rendered == "":
        raise ValueError("Empty maze ASCII encountered.")

    # Paragraphs of the user turn, separated by blank lines when joined.
    sections = (
        "You need to solve the following maze.",
        "'*' denotes the wall that you cannot walk through, '.' denotes available area that you can walk through. "
        "'S' denotes the starting point, 'E' denotes the destination.",
        "You need to start from the starting point and cross through the available area to reach the destination. "
        "There are four movement actions, including Left, Right, Up, Down.",
        "Use L to denote Left movement, R to denote Right movement, U to denote Up movement, and D to denote Down movement.",
        "You can analyze the maze to find the correct path, and you should write the final path in the "
        "<answer></answer>, e.g., <answer>LLRRDUL</answer>.",
        "## Maze",
        rendered,
        "Now try to analyze the maze and put the final path in the <answer></answer>.",
    )
    user_turn = {"role": "user", "content": "\n\n".join(sections).strip()}

    if not system_prompt:
        return [user_turn]
    return [{"role": "system", "content": system_prompt}, user_turn]


def _solution_to_moves(solution: Any) -> str:
    """Translate a path of (row, col) lattice coordinates into U/D/L/R letters.

    A row decrease maps to "U", a row increase to "D", a column decrease to
    "L" and a column increase to "R"; one letter is produced per step.

    Raises:
        ValueError: If the path is not an (N, 2) array with N >= 2, or if any
            two consecutive points are not exactly one unit apart along a
            single axis.
    """
    path = np.asarray(solution, dtype=int)
    is_valid_path = path.ndim == 2 and path.shape[1] == 2 and len(path) >= 2
    if not is_valid_path:
        raise ValueError(f"Invalid solution path shape: {path.shape}")

    letter_for_step = {(-1, 0): "U", (1, 0): "D", (0, -1): "L", (0, 1): "R"}

    letters = []
    for here, there in zip(path[:-1], path[1:]):
        step = (int(there[0] - here[0]), int(there[1] - here[1]))
        letter = letter_for_step.get(step)
        if letter is None:
            r0, c0 = here
            r1, c1 = there
            raise ValueError(f"Non-adjacent step in solution: {(r0, c0)} -> {(r1, c1)}")
        letters.append(letter)
    return "".join(letters)


def _expand_moves_to_ascii(moves: str) -> str:
    """Repeat every move letter twice, in order.

    In the ASCII rendering neighbouring cell centres are two characters
    apart, so one lattice step corresponds to two ASCII-grid steps.
    """
    doubled = [step + step for step in moves]
    return "".join(doubled)


def _safe_pos(value: Any, fallback: np.ndarray) -> list[int]:
    """Return ``value`` as a two-element list of Python ints, else ``fallback``.

    ``value`` is used when it is not None and flattens to exactly two
    integers; otherwise the first two entries of ``fallback`` are returned.
    ``fallback`` is only read when it is actually needed.
    """
    if value is not None:
        flat = np.asarray(value, dtype=int).ravel()
        if flat.size == 2:
            first, second = flat
            return [int(first), int(second)]
    return [int(fallback[0]), int(fallback[1])]


def _render_ascii_fallback(
    connection_list: np.ndarray,
    start_pos: np.ndarray,
    end_pos: np.ndarray,
) -> str:
    """Render a lattice maze as a (2*rows+1) x (2*cols+1) character grid.

    Lattice cell (r, c) is drawn at grid position (2r+1, 2c+1). A truthy
    ``connection_list[0, r, c]`` opens the character directly below that
    cell and a truthy ``connection_list[1, r, c]`` opens the character
    directly to its right. Everything else is a '*' wall.

    Args:
        connection_list: Array of shape [>=2, rows, cols]; only planes 0
            (downward links) and 1 (rightward links) are read.
        start_pos: (row, col) of the cell to mark with 'S'.
        end_pos: (row, col) of the cell to mark with 'E'. If it equals
            ``start_pos`` the cell shows 'E', since it is written last.

    Returns:
        The grid rows joined with newlines (no trailing newline).

    Raises:
        ValueError: If ``connection_list`` is not 3-D with at least two
            planes, or if a position lies outside the lattice. Negative
            coordinates are rejected rather than wrapped around.
    """
    connections = np.asarray(connection_list)
    if connections.ndim != 3 or connections.shape[0] < 2:
        raise ValueError(f"Unexpected connection_list shape: {connections.shape}")
    n_rows, n_cols = int(connections.shape[1]), int(connections.shape[2])

    def _checked_cell(pos: Any, label: str) -> tuple:
        # NumPy would silently wrap negative indices onto the far side of
        # the grid, so bounds are checked explicitly.
        row, col = int(pos[0]), int(pos[1])
        if not (0 <= row < n_rows and 0 <= col < n_cols):
            raise ValueError(
                f"{label} position {(row, col)} is outside the {n_rows}x{n_cols} lattice."
            )
        return row, col

    start_r, start_c = _checked_cell(start_pos, "start")
    end_r, end_c = _checked_cell(end_pos, "end")

    grid = np.full((2 * n_rows + 1, 2 * n_cols + 1), "*", dtype="<U1")

    # Odd/odd positions are the cell centres. The slices below are views,
    # so masked assignment writes straight into ``grid``.
    grid[1::2, 1::2] = "."
    grid[1::2, 2::2][connections[1].astype(bool)] = "."  # rightward passages
    grid[2::2, 1::2][connections[0].astype(bool)] = "."  # downward passages

    grid[2 * start_r + 1, 2 * start_c + 1] = "S"
    grid[2 * end_r + 1, 2 * end_c + 1] = "E"

    return "\n".join("".join(row) for row in grid)


def _maze_to_ascii(maze: Any, connection_list: np.ndarray, start_pos: np.ndarray, end_pos: np.ndarray) -> str:
    """Return the ASCII rendering of a maze.

    The ``maze`` argument is accepted but not consulted; the text is
    produced solely from the connection list and the start/end positions
    by ``_render_ascii_fallback``.
    """
    return _render_ascii_fallback(
        connection_list=connection_list,
        start_pos=start_pos,
        end_pos=end_pos,
    )


def _make_row(
    maze: Any,
    idx: int,
    *,
    split: str,
    data_source: str,
    interaction_name: str,
    k_max_attempts: int,
    system_prompt: Optional[str],
) -> Dict[str, Any]:
    """Convert one solved maze into a dataset row.

    The maze object is read through ``getattr`` for ``connection_list`` and
    ``solution`` (both required) and ``start_pos`` / ``end_pos`` (optional;
    the first / last solution coordinate is used when one is missing, None,
    or does not flatten to two integers).

    Args:
        maze: Solved maze object exposing the attributes described above.
        idx: Row index stored under ``extra_info["index"]``.
        split: Split label stored under ``extra_info["split"]``.
        data_source: Value of the ``data_source`` column.
        interaction_name: Stored as ``interaction_kwargs["name"]``.
        k_max_attempts: Stored as ``interaction_kwargs["max_attempts"]``.
        system_prompt: Optional system message prepended to the prompt.

    Returns:
        A dict with keys ``data_source``, ``prompt``, ``ability``,
        ``reward_model`` and ``extra_info``.

    Raises:
        ValueError: If the connection list is missing, empty or not 3-D, or
            if the solution is missing or not a valid path of adjacent
            (row, col) coordinates.
    """
    raw_connection_list = getattr(maze, "connection_list", None)
    if raw_connection_list is None:
        raise ValueError("Maze missing connection_list.")
    connection_list = np.asarray(raw_connection_list)
    if connection_list.size == 0:
        raise ValueError("Maze has empty connection_list.")
    if connection_list.ndim != 3:
        raise ValueError(f"Unexpected connection_list shape: {connection_list.shape}")

    solution = getattr(maze, "solution", None)
    if solution is None:
        raise ValueError("Maze missing solution.")

    coords = np.asarray(solution, dtype=int)
    if coords.ndim != 2 or coords.shape[1] != 2:
        raise ValueError(f"Unexpected solution shape: {coords.shape}")

    # Validate the path before indexing coords[0] / coords[-1], so an empty
    # or too-short solution surfaces as a ValueError instead of IndexError.
    ground_truth = _expand_moves_to_ascii(_solution_to_moves(coords))

    # Resolve start/end exactly once so the 'S'/'E' drawn in the prompt and
    # the positions stored in the metadata can never disagree.
    start_pos = _safe_pos(getattr(maze, "start_pos", None), coords[0])
    end_pos = _safe_pos(getattr(maze, "end_pos", None), coords[-1])

    ascii_maze = _maze_to_ascii(
        maze,
        connection_list,
        np.asarray(start_pos, dtype=int),
        np.asarray(end_pos, dtype=int),
    )
    prompt = _build_prompt_messages(ascii_maze, system_prompt=system_prompt)

    row: Dict[str, Any] = {
        "data_source": data_source,
        "prompt": prompt,
        "ability": DEFAULT_ABILITY,
        "reward_model": {
            "style": "rule",
            "ground_truth": ground_truth,
        },
        "extra_info": {
            "split": split,
            "index": idx,
            "interaction_kwargs": {
                "name": interaction_name,
                "query": prompt[-1]["content"],
                "ground_truth": ground_truth,
                "max_attempts": int(k_max_attempts),
                "maze": {
                    "grid_n": int(connection_list.shape[1]),
                    # Plain nested lists of 0/1 ints for serialization.
                    "connection_list": np.asarray(connection_list, dtype=np.int8).tolist(),
                    "start_pos": start_pos,
                    "end_pos": end_pos,
                },
            },
        },
    }
    return row


def _generate_mazes(
    *,
    grid_n: int,
    n_mazes: int,
    seed: int,
    name: str,
    local_base_path: Path,
) -> list[Any]:
    """Create ``n_mazes`` mazes with the DFS lattice generator.

    Builds a ``MazeDatasetConfig`` from the given name, grid size, count and
    seed (generator ``LatticeMazeGenerators.gen_dfs``, no extra generator
    kwargs), hands it to ``MazeDataset.from_config`` with ``verbose=True``,
    and returns the dataset's ``mazes`` as a list.

    ``local_base_path`` is forwarded to ``from_config`` unchanged
    (presumably the library's on-disk location for the dataset; confirm
    against the maze-dataset documentation).
    """
    config_fields = dict(
        name=name,
        grid_n=grid_n,
        n_mazes=n_mazes,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
        maze_ctor_kwargs={},
        seed=seed,
    )
    generated = MazeDataset.from_config(
        MazeDatasetConfig(**config_fields),
        local_base_path=local_base_path,
        verbose=True,
    )
    return [*generated.mazes]


def _run_sanity_check(dataset: Any, max_rows: int = 3) -> bool:
    """Score the stored ground truth of the first few rows with the maze verifier.

    Returns:
        True if the check ran and every checked row scored 1.0; False if the
        verifier module could not be imported and the check was skipped.

    Raises:
        RuntimeError: If the verifier raises, or a checked row's ground truth
            does not score exactly 1.0.
    """
    try:
        from examples.reward_fns.maze_path_reward import compute_score as maze_compute_score
    except ImportError as e:
        # A missing verifier says nothing about the data itself, so this is a
        # warning rather than a failure of an otherwise successful run.
        print(
            "[WARN] Sanity check skipped: could not import "
            f"examples.reward_fns.maze_path_reward ({e}). "
            "Make sure the repository root is on PYTHONPATH to enable it."
        )
        return False

    for i in range(min(max_rows, len(dataset))):
        row = dataset[i]
        gt = row["reward_model"]["ground_truth"]
        tagged = f"<answer>{gt}</answer>"
        try:
            score = maze_compute_score(
                tagged, gt, data_source=row["data_source"], extra_info=row["extra_info"]
            )
        except Exception as e:
            raise RuntimeError(f"Sanity check failed: {e}") from e
        if score != 1.0:
            raise RuntimeError(f"Sanity check failed on row {i}: score={score}")

    print("[OK] Sanity check: compute_score(ground_truth) == 1.0")
    return True


def main():
    """CLI entry point: generate mazes, write parquet files, verify, optionally upload.

    Steps, in order:
      1. Parse and validate command-line arguments.
      2. Generate the train mazes (seed) and test mazes (seed + 1).
      3. Convert them to rows and write train.parquet / test.parquet into
         the local save directory.
      4. Print a sample row and run the ground-truth sanity check.
      5. If --hdfs_save_dir is set, copy both parquet files there. This runs
         only after the sanity check, so data that fails verification is
         never uploaded.

    Raises:
        ValueError: On invalid argument values.
        RuntimeError: If the sanity check fails, or --hdfs_save_dir is set
            but verl.utils.hdfs_io cannot be imported.
    """
    parser = argparse.ArgumentParser(description="Preprocess maze dataset for verl Ver@K retry interaction training.")
    parser.add_argument("--local_save_dir", type=str, default=None,
                        help="Local output directory for train.parquet and test.parquet. "
                        "Defaults to ./data/maze_<grid_n>x<grid_n>_ver_k_retry_k<k_max_attempts>.")
    parser.add_argument("--hdfs_save_dir", type=str, default=None,
                        help="Optional HDFS output dir. If set, parquet files are copied there too.")
    parser.add_argument("--grid_n", type=int, default=5, help="Maze grid size (N for an NxN lattice).")
    parser.add_argument("--n_train", type=int, default=10000, help="Number of training mazes to generate.")
    parser.add_argument("--n_test", type=int, default=1000, help="Number of test mazes to generate.")
    parser.add_argument("--seed", type=int, default=1234, help="Base seed for deterministic generation.")
    parser.add_argument("--interaction_name", type=str, default="ver_k_retry",
                        help="Must match the 'name' your interaction registry/config expects.")
    parser.add_argument("--k_max_attempts", type=int, default=4,
                        help="K: maximum retries/turns for your Ver@K interaction agent.")
    parser.add_argument("--num_proc", type=int, default=1,
                        help="Reserved for future use; maze generation is currently single-threaded.")
    parser.add_argument("--system_prompt", type=str, default=None,
                        help="Optional system prompt. If not set, we only use a user message.")
    parser.add_argument("--keep_original_columns", action="store_true",
                        help="Unused for generated mazes; present for CLI parity.")
    args = parser.parse_args()

    if args.k_max_attempts <= 0:
        raise ValueError("--k_max_attempts must be >= 1")
    if args.grid_n <= 1:
        raise ValueError("--grid_n must be >= 2")
    if args.n_train <= 0 or args.n_test <= 0:
        raise ValueError("--n_train and --n_test must be >= 1")

    default_save_dir = f"./data/maze_{args.grid_n}x{args.grid_n}_ver_k_retry_k{args.k_max_attempts}"
    local_save_dir = os.path.expanduser(args.local_save_dir or default_save_dir)
    os.makedirs(local_save_dir, exist_ok=True)

    data_source = f"maze_dataset/dfs_{args.grid_n}x{args.grid_n}"

    # Directory handed to the maze generator as its local base path.
    raw_base = Path(local_save_dir) / "raw_maze_dataset"
    raw_base.mkdir(parents=True, exist_ok=True)

    train_mazes = _generate_mazes(
        grid_n=args.grid_n,
        n_mazes=args.n_train,
        seed=args.seed,
        name=f"maze_train_{args.grid_n}x{args.grid_n}_seed{args.seed}",
        local_base_path=raw_base,
    )
    # The test split uses a different seed than the train split.
    test_mazes = _generate_mazes(
        grid_n=args.grid_n,
        n_mazes=args.n_test,
        seed=args.seed + 1,
        name=f"maze_test_{args.grid_n}x{args.grid_n}_seed{args.seed + 1}",
        local_base_path=raw_base,
    )

    def build_rows(mazes, split):
        # Row indices restart at 0 within each split.
        return [
            _make_row(
                maze,
                idx,
                split=split,
                data_source=data_source,
                interaction_name=args.interaction_name,
                k_max_attempts=args.k_max_attempts,
                system_prompt=args.system_prompt,
            )
            for idx, maze in enumerate(mazes)
        ]

    train_out = datasets.Dataset.from_list(build_rows(train_mazes, "train"))
    test_out = datasets.Dataset.from_list(build_rows(test_mazes, "test"))

    train_path = os.path.join(local_save_dir, "train.parquet")
    test_path = os.path.join(local_save_dir, "test.parquet")

    train_out.to_parquet(train_path)
    test_out.to_parquet(test_path)

    print(f"[OK] Wrote: {train_path} ({len(train_out)} rows)")
    print(f"[OK] Wrote: {test_path} ({len(test_out)} rows)")

    # Quick sanity print + verifier check, done before any upload.
    print("\nSample row (train[0]):")
    print(train_out[0])
    _run_sanity_check(train_out)

    if args.hdfs_save_dir:
        try:
            from verl.utils.hdfs_io import copy as hdfs_copy
            from verl.utils.hdfs_io import makedirs as hdfs_makedirs
        except Exception as e:
            raise RuntimeError(
                "You set --hdfs_save_dir but verl.utils.hdfs_io could not be imported. "
                "Run inside verl repo/env or remove --hdfs_save_dir."
            ) from e

        hdfs_makedirs(args.hdfs_save_dir)
        hdfs_copy(src=train_path, dst=os.path.join(args.hdfs_save_dir, "train.parquet"))
        hdfs_copy(src=test_path, dst=os.path.join(args.hdfs_save_dir, "test.parquet"))
        print(f"[OK] Copied parquet files to HDFS dir: {args.hdfs_save_dir}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
