"""Generate and classify illegal Tic Tac Toe boards (configurable starting player).

Legal generation already exists in `generate_dataset_v0.py`. This script enumerates
all 3^9 boards (values 0=empty, 1=P1 (X), 2=P2 (O)) and filters those that are
NOT reachable under specified Tic Tac Toe rules:
    - You can choose the start player: X, O, or allow BOTH (union of both games).
    - Players alternate; game stops immediately once someone wins.

We also assign reason codes explaining why a board is illegal. Categories:

  COUNT_DIFF_GT1        -> |#X - #O| > 1
  START_COUNT_INVALID   -> Piece counts cannot arise from any allowed start player
  DOUBLE_WIN            -> Both players show winning lines simultaneously
  WIN_COUNT_MISMATCH    -> Winner present but move counts inconsistent with who moved last
  CONTINUED_AFTER_WIN   -> Board passes superficial count constraints but implies
                           continued play after an earlier win
  GENERIC_UNREACHABLE   -> Fallback if none of the above matched (should be rare)

Approach for reachability check (standard algorithm similar to LeetCode 794):
  1. Basic count constraints.
  2. Win constraints (cannot both win; winner must have correct counts).
  3. Recursive backward simulation: try removing one plausible last move so that
     the previous board is also valid. Memoized for efficiency.

Because total boards = 19,683, exhaustive enumeration is trivial.

CLI Usage:
  python generate_illegal_boards.py \
      --sample-size 500 \
      --output /tmp/illegal_ttt.json

Outputs a JSON list with objects:
  {
    "board": [..9 ints..],
    "ascii_board": "X . O\n...",
    "count_x": int,
    "count_o": int,
    "x_win": bool,
    "o_win": bool,
    "reasons": [list of reason codes],
    "start_mode": "x" | "o" | "both"
  }

If --sample-size is omitted, the full set is written.
"""

from __future__ import annotations
import argparse
import itertools
import json
from functools import lru_cache
from typing import List, Tuple, Iterable, Optional

# Every index triple that completes a line on the 3x3 board (cells are stored
# row by row, so cell (r, c) lives at index r*3 + c).
WIN_LINES: Tuple[Tuple[int, int, int], ...] = (
    # Rows: three consecutive cells starting at 0, 3 and 6.
    *((3 * row, 3 * row + 1, 3 * row + 2) for row in range(3)),
    # Columns: cells three apart starting at 0, 1 and 2.
    *((col, col + 3, col + 6) for col in range(3)),
    # Diagonals.
    (0, 4, 8),
    (2, 4, 6),
)


def is_win(board: List[int], player: int) -> bool:
    """Return True when `player` fills all three cells of some line in WIN_LINES."""
    for first, second, third in WIN_LINES:
        if board[first] == board[second] == board[third] == player:
            return True
    return False


def check_winner(board: List[int]) -> int:
    """Return the player (1 or 2) that has a completed line, or 0 if neither does.

    Player 1 is tested first, so a board on which both players have a line
    reports 1.
    """
    for player in (1, 2):
        if is_win(board, player):
            return player
    return 0


@lru_cache(None)
def _is_reachable_given_start(board_tuple: Tuple[int, ...], start_player: int) -> bool:
    """Decide whether a position can occur in a game opened by `start_player`.

    The position is accepted when its piece counts and winning lines are
    consistent and at least one piece of the last mover can be taken back to
    give a position that is itself reachable. Results are memoized per
    (board, start player) pair.

    Args:
        board_tuple: The 9 cells (0 empty, 1 = X, 2 = O) as a hashable tuple.
        start_player: 1 if X opens the game, 2 if O does.

    Returns:
        True if the position is reachable, False otherwise.
    """
    x_total = board_tuple.count(1)
    o_total = board_tuple.count(2)
    if start_player == 1:
        opener_total, replier_total = x_total, o_total
    else:
        opener_total, replier_total = o_total, x_total

    # The opener is never behind and never more than one piece ahead.
    lead = opener_total - replier_total
    if lead < 0 or lead > 1:
        return False

    cells = list(board_tuple)
    x_wins = is_win(cells, 1)
    o_wins = is_win(cells, 2)
    if x_wins and o_wins:
        return False

    # A winner made the final move: the opener wins one piece ahead, the
    # second player wins with equal counts.
    if x_wins and x_total != (o_total + 1 if start_player == 1 else o_total):
        return False
    if o_wins and o_total != (x_total + 1 if start_player == 2 else x_total):
        return False

    # The empty board is the root of every game.
    if not (x_total or o_total):
        return True

    has_winner = x_wins or o_wins
    if x_wins:
        mover = 1
    elif o_wins:
        mover = 2
    elif opener_total > replier_total:
        mover = start_player
    else:
        mover = 2 if start_player == 1 else 1

    # Take back each candidate last move in turn and recurse on the result.
    for idx, mark in enumerate(board_tuple):
        if mark != mover:
            continue
        earlier = board_tuple[:idx] + (0,) + board_tuple[idx + 1:]
        # With a winner on the board, the removed piece must have been the one
        # that completed the line; otherwise play went on after a win.
        if has_winner and is_win(list(earlier), mover):
            continue
        if _is_reachable_given_start(earlier, start_player):
            return True
    return False


@lru_cache(None)
def is_reachable_board(board_tuple: Tuple[int, ...], start_mode: str) -> bool:
    """Report whether a board is reachable under the given start mode.

    Args:
        board_tuple: The 9 cells (0 empty, 1 = X, 2 = O) as a hashable tuple.
        start_mode: 'x' or 'o' for a fixed opener, 'both' to accept a board
            reachable with either opener.

    Returns:
        True if the board is reachable for at least one permitted opener.

    Raises:
        ValueError: If `start_mode` is not 'x', 'o' or 'both'.
    """
    openers_by_mode = {'x': (1,), 'o': (2,), 'both': (1, 2)}
    if start_mode not in openers_by_mode:
        raise ValueError(f"Invalid start_mode: {start_mode}")
    return any(
        _is_reachable_given_start(board_tuple, opener)
        for opener in openers_by_mode[start_mode]
    )


def board_to_ascii(board: List[int]) -> str:
    """Render a 9-cell board as three newline-separated rows.

    Cells are drawn as '.', 'X' and 'O' for the values 0, 1 and 2, with a
    single space between the cells of a row.
    """
    glyph = {0:'.', 1:'X', 2:'O'}
    lines = []
    for row_start in (0, 3, 6):
        row_cells = [glyph[board[row_start + offset]] for offset in range(3)]
        lines.append(" ".join(row_cells))
    return "\n".join(lines)


# Reason codes accepted by the --reasons filter; generate_illegal_boards
# rejects any code that is not listed here.
VALID_REASON_CODES = [
    "COUNT_DIFF_GT1",        # piece counts differ by more than one
    "DOUBLE_WIN",            # both players have a completed line
    "WIN_COUNT_MISMATCH",    # winner's piece count does not fit who moved last
    "CONTINUED_AFTER_WIN",   # play went on after a win
    "START_COUNT_INVALID",   # counts impossible for every allowed start player
    "GENERIC_UNREACHABLE",   # fallback code
]


def classify_illegal(board: List[int], start_mode: str) -> List[str]:
    """Explain why a board is illegal by returning one or more reason codes.

    The function does not check reachability itself; it is meant for boards
    already known to be unreachable under `start_mode`.

    Args:
        board: The 9 cells (0 empty, 1 = X, 2 = O).
        start_mode: 'x', 'o' or 'both' -- which player(s) may open the game.

    Returns:
        A non-empty list of reason codes. COUNT_DIFF_GT1, DOUBLE_WIN and
        WIN_COUNT_MISMATCH are reported together when several apply. If none
        of them applies, exactly one fallback code is returned:
        CONTINUED_AFTER_WIN when a winner is present, START_COUNT_INVALID when
        the counts fit no allowed start player, GENERIC_UNREACHABLE otherwise.

    Raises:
        ValueError: If `start_mode` is not 'x', 'o' or 'both'.
    """
    # Validate up front, with the same error is_reachable_board raises, so an
    # unknown mode is never silently treated as 'o'.
    allowed_starts: Iterable[str]
    if start_mode == 'both':
        allowed_starts = ('x', 'o')
    elif start_mode in ('x', 'o'):
        allowed_starts = (start_mode,)
    else:
        raise ValueError(f"Invalid start_mode: {start_mode}")

    reasons = []
    count_x = board.count(1)
    count_o = board.count(2)
    x_win = is_win(board, 1)
    o_win = is_win(board, 2)

    def counts_fit(start: str) -> bool:
        # The opener has either as many pieces as the other player or one more.
        lead = count_x - count_o if start == 'x' else count_o - count_x
        return lead in (0, 1)

    def winner_consistent(start: str) -> bool:
        # A winner made the last move: the opener wins one piece ahead, the
        # second player wins with equal counts.
        if x_win:
            if start == 'x':
                return count_x == count_o + 1
            return count_o == count_x
        if o_win:
            if start == 'o':
                return count_o == count_x + 1
            return count_x == count_o
        return True

    # Raw count sanity independent of start_mode (|diff| > 1 always impossible)
    if abs(count_x - count_o) > 1:
        reasons.append("COUNT_DIFF_GT1")

    if x_win and o_win:
        reasons.append("DOUBLE_WIN")

    # Winner count mismatch: only reported when it fails for every allowed start.
    if (x_win or o_win) and not any(winner_consistent(s) for s in allowed_starts):
        reasons.append("WIN_COUNT_MISMATCH")

    if not reasons:
        if x_win or o_win:
            reasons.append("CONTINUED_AFTER_WIN")
        elif not any(counts_fit(s) for s in allowed_starts):
            # Counts differ by at most one but favour the wrong player.
            reasons.append("START_COUNT_INVALID")
        else:
            # None of the specific rules explains the board.
            reasons.append("GENERIC_UNREACHABLE")
    return reasons


def _balanced_sample(entries, sample_size, categories):
    """Pick `sample_size` entries spread as evenly as possible over `categories`.

    Each entry is assigned to the first of its reasons that is a category (or
    to the first category if none matches). Every category gets an equal base
    quota, the remainder goes one apiece to the largest groups, and any
    shortfall caused by small groups is filled from the unused entries in
    their original order. `categories` must be non-empty.
    """
    groups = {c: [] for c in categories}
    for entry in entries:
        primary = next((r for r in entry["reasons"] if r in groups), categories[0])
        groups[primary].append(entry)

    base_target, leftover = divmod(sample_size, len(categories))
    targets = {c: min(base_target, len(groups[c])) for c in categories}

    # Hand the remainder to the largest groups first (stable sort, so ties
    # keep the category order).
    by_size = sorted(categories, key=lambda c: len(groups[c]), reverse=True)
    for c in by_size:
        if leftover <= 0:
            break
        if targets[c] < len(groups[c]):
            targets[c] += 1
            leftover -= 1

    selected = []
    for c in categories:
        if targets[c] > 0:
            selected.extend(groups[c][:targets[c]])

    # If still short (some categories exhausted), fill from remaining boards
    if len(selected) < sample_size:
        used_ids = {id(e) for e in selected}
        extras = [e for e in entries if id(e) not in used_ids]
        selected.extend(extras[:sample_size - len(selected)])
    return selected[:sample_size]


def generate_illegal_boards(start_mode: str = 'x',
                            sample_size: Optional[int] = None,
                            include_reasons: Optional[List[str]] = None,
                            reason_mode: str = 'any',
                            balanced: bool = True):
    """Enumerate all 3^9 boards and return those unreachable under `start_mode`.

    Args:
        start_mode: 'x', 'o' or 'both' -- which player(s) may open the game.
        sample_size: Maximum number of boards to return; None returns all.
        include_reasons: If given, keep only boards whose reason codes match
            these codes according to `reason_mode`.
        reason_mode: 'any' keeps a board sharing at least one code with
            `include_reasons`; 'all' keeps it only if it carries every code.
        balanced: When sampling, spread the sample across reason categories
            instead of taking the first `sample_size` boards.

    Returns:
        A list of dicts with keys "board", "ascii_board", "count_x",
        "count_o", "x_win", "o_win", "reasons" and "start_mode".

    Raises:
        ValueError: If `sample_size` is negative, `reason_mode` is not 'any'
            or 'all', `include_reasons` contains an unknown code, or
            `start_mode` is invalid (raised by is_reachable_board).
    """
    # A negative size would otherwise be used as a slice bound and return
    # almost everything when balancing is off.
    if sample_size is not None and sample_size < 0:
        raise ValueError(f"sample_size must be non-negative, got {sample_size}")
    # An unknown mode would otherwise silently disable the reason filter.
    if reason_mode not in ('any', 'all'):
        raise ValueError(f"Invalid reason_mode: {reason_mode}")

    include_set = set(include_reasons) if include_reasons else None
    if include_set:
        unknown = include_set - set(VALID_REASON_CODES)
        if unknown:
            raise ValueError(f"Unknown reason codes: {sorted(unknown)}")

    illegal = []
    for b in itertools.product([0, 1, 2], repeat=9):
        if is_reachable_board(b, start_mode):
            continue
        board_list = list(b)
        reasons = classify_illegal(board_list, start_mode)
        if include_set:
            found = set(reasons)
            if reason_mode == 'any':
                keep = bool(found & include_set)
            else:
                keep = include_set.issubset(found)
            if not keep:
                continue
        illegal.append({
            "board": board_list,
            "ascii_board": board_to_ascii(board_list),
            "count_x": board_list.count(1),
            "count_o": board_list.count(2),
            "x_win": is_win(board_list, 1),
            "o_win": is_win(board_list, 2),
            "reasons": reasons,
            "start_mode": start_mode
        })

    if sample_size is None or sample_size >= len(illegal):
        return illegal
    if not balanced:
        # Simple slice (deterministic)
        return illegal[:sample_size]

    # Balance over the requested codes, or over the primary (first) reason of
    # each board. Sorting keeps the sample reproducible: iterating the set
    # directly would follow the per-process string hash order.
    if include_set:
        categories = sorted(include_set)
    else:
        categories = sorted({entry["reasons"][0] for entry in illegal})
    return _balanced_sample(illegal, sample_size, categories)


def main():
    """Command-line entry point.

    Parses the options, generates the illegal boards, writes them as JSON to
    the --output path and prints the board count plus a tally of reason codes
    (most frequent first, ties broken alphabetically).
    """
    cli = argparse.ArgumentParser(description="Generate illegal Tic Tac Toe boards with reason codes.")
    cli.add_argument("--output", type=str, required=True, help="Path to write JSON output")
    cli.add_argument("--sample-size", type=int, default=None, help="Optional cap on number of illegal boards to save")
    cli.add_argument("--pretty", action="store_true", help="Pretty-print JSON")
    cli.add_argument("--start-player", choices=["x","o","both"], default="x", help="Who can start (x, o, both)")
    cli.add_argument("--reasons", nargs="+", help="Filter: only include boards whose reasons match (ANY by default unless --reason-mode all). Valid: " + ", ".join(VALID_REASON_CODES))
    cli.add_argument("--reason-mode", choices=["any","all"], default="any", help="Filtering logic when --reasons provided")
    cli.add_argument("--no-balance", dest="balance", action="store_false", help="Disable balanced sampling across reason codes (default on)")
    options = cli.parse_args()

    boards = generate_illegal_boards(
        start_mode=options.start_player,
        sample_size=options.sample_size,
        include_reasons=options.reasons,
        reason_mode=options.reason_mode,
        balanced=options.balance,
    )
    print(f"Illegal boards generated: {len(boards)}")

    # indent=None is json.dump's default, i.e. the compact single-line form.
    with open(options.output, "w") as handle:
        json.dump(boards, handle, indent=2 if options.pretty else None)

    # Quick reason distribution: a board with several codes counts once per code.
    tally = {}
    for record in boards:
        for code in record["reasons"]:
            tally[code] = tally.get(code, 0) + 1
    print("Reason distribution:")
    for code, total in sorted(tally.items(), key=lambda pair: (-pair[1], pair[0])):
        print(f"  {code}: {total}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
