#!/usr/bin/env python3
"""
Generate Sudoku puzzles and their solutions (using logic adapted from generate_sudoku.py) and export them as PNG images and JSON.

Supports Sudoku grids whose side length is a perfect square (e.g., 4x4, 9x9, 16x16).

Output structure:
- outputs/
  - question/
    - 001.png, 002.png, 003.png, ...
  - answer/
    - 001.png, 002.png, 003.png, ...
  - sudoku/
    - 001.json, 002.json, 003.json, ...

Usage:
  python3 sudoku/export_sudoku_images.py --count "20,80" --outdir outputs --givens "10,40" --grid-size "4,9"

Note: grid-size, givens, and count lists must have the same length and correspond one-to-one.
For example: --grid-size "4,9" --givens "10,40" --count "20,80" means:
- 4x4 Sudoku with 10 given numbers, generate 20
- 9x9 Sudoku with 40 given numbers, generate 80

Optional arguments:
  --outdir Output root directory.
  --digits-font Font path for rendering digits (optional).
  --size Image size (square side length), default 1024.
  --line-width Grid line width, default 2.
  --grid-size Sudoku grid size list, comma-separated, e.g., "4,9".
  --givens Number of given digits list, comma-separated, e.g., "10,40".
  --count Generation count list, comma-separated, e.g., "20,80".
  --seed Random seed.
"""

import argparse
import json
import os
import random
from typing import List, Tuple, Dict

from PIL import Image, ImageDraw, ImageFont


def construct_puzzle_solution(size: int = 9) -> List[List[int]]:
    """Generate a random, fully solved ``size`` x ``size`` Sudoku grid.

    Cells are filled in row-major order, each with a random value still
    allowed by its row, column and box. When a cell has no allowed value the
    partial grid is discarded and filling restarts from an empty grid (there is
    no backtracking). Adapted from generate_sudoku.py for Python 3.

    Args:
        size: Side length of the grid; must be a perfect square
            (1, 4, 9, 16, ...). A size of 0 yields an empty grid.

    Returns:
        A list of ``size`` rows in which every row, column and box contains
        each of ``1..size`` exactly once.

    Raises:
        ValueError: If ``size`` is negative or not a perfect square (such
            sizes have no valid box layout and could never be filled).
    """
    box = int(round(size ** 0.5)) if size > 0 else 0
    if box * box != size:
        raise ValueError(
            f"Sudoku grid size must be a perfect square (e.g. 4, 9, 16), got {size}"
        )

    while True:
        puzzle = [[0] * size for _ in range(size)]
        rows = [set(range(1, size + 1)) for _ in range(size)]
        columns = [set(range(1, size + 1)) for _ in range(size)]
        squares = [set(range(1, size + 1)) for _ in range(size)]

        stuck = False
        for i in range(size):
            for j in range(size):
                square = squares[(i // box) * box + (j // box)]
                choices = rows[i].intersection(columns[j]).intersection(square)
                if not choices:
                    # Dead end: no value fits this cell; start over from scratch.
                    stuck = True
                    break
                choice = random.choice(list(choices))

                puzzle[i][j] = choice
                rows[i].discard(choice)
                columns[j].discard(choice)
                square.discard(choice)
            if stuck:
                break

        if not stuck:
            return puzzle


def _can_be_a(puz: List[List[int]], i: int, j: int, c: int, size: int = 9) -> bool:
    v = puz[c // size][c % size]
    if puz[i][j] == v:
        return True
    if puz[i][j] in range(1, size + 1):
        return False

    for m in range(size):
        if not (m == c // size and j == c % size) and puz[m][j] == v:
            return False
        if not (i == c // size and m == c % size) and puz[i][m] == v:
            return False
        if not (
            ((i // int(size ** 0.5)) * int(size ** 0.5) + (m // int(size ** 0.5)) == c // size)
            and ((j // int(size ** 0.5)) * int(size ** 0.5) + (m % int(size ** 0.5)) == c % size)
        ) and puz[(i // int(size ** 0.5)) * int(size ** 0.5) + (m // int(size ** 0.5))][(j // int(size ** 0.5)) * int(size ** 0.5) + (m % int(size ** 0.5))] == v:
            return False

    return True


def pluck(puzzle: List[List[int]], n: int = 0, size: int = 9) -> Tuple[List[List[int]], int]:
    """Blank out cells of a solved grid while each removal stays deducible.

    Cells are tried in random order. A cell is blanked (set to 0) unless, in
    each of its three groups -- column, row and box -- some other cell could
    still hold the cell's value according to ``_can_be_a``. In other words a
    cell is removed when at least one of its groups pins its value down.
    Stops once at most ``n`` cells remain filled or every cell has been tried.

    Note: ``puzzle`` is modified in place and also returned.

    Args:
        puzzle: A complete ``size`` x ``size`` solution; mutated in place.
        n: Target number of filled cells (givens) at which to stop.
        size: Side length of the grid (a perfect square).

    Returns:
        ``(puzzle, remaining)`` where ``remaining`` is the number of filled
        cells left; it can exceed ``n`` if no further cell could be removed.
    """
    box = int(size ** 0.5)
    cells = set(range(size * size))  # cells that are still filled
    cellsleft = cells.copy()  # cells not tried yet
    while len(cells) > n and cellsleft:
        cell = random.choice(list(cellsleft))
        cellsleft.discard(cell)
        cell_row, cell_col = divmod(cell, size)
        # Top-left corner of the box containing the cell: the row start comes
        # from the cell's row and the column start from the cell's column.
        box_top = (cell_row // box) * box
        box_left = (cell_col // box) * box

        # "ambiguous" = another cell of that group could also take this value,
        # so the group alone cannot be used to deduce it.
        col_ambiguous = row_ambiguous = box_ambiguous = False
        for k in range(size):
            if not col_ambiguous and k != cell_row and _can_be_a(puzzle, k, cell_col, cell, size):
                col_ambiguous = True
            if not row_ambiguous and k != cell_col and _can_be_a(puzzle, cell_row, k, cell, size):
                row_ambiguous = True
            r = box_top + k // box
            c = box_left + k % box
            if not box_ambiguous and (r, c) != (cell_row, cell_col) and _can_be_a(puzzle, r, c, cell, size):
                box_ambiguous = True

        if col_ambiguous and row_ambiguous and box_ambiguous:
            # No group determines this value; keep the cell and try another.
            continue
        puzzle[cell_row][cell_col] = 0
        cells.discard(cell)

    return puzzle, len(cells)


def generate_one(n_givens: int = 23, max_iter: int = 10, size: int = 9) -> Tuple[List[List[int]], List[List[int]]]:
    """Generate a puzzle with (ideally) ``n_givens`` givens, plus its solution.

    One solution grid is built, then up to ``max_iter`` independent pluck
    attempts are made, each on a fresh copy of it. The first attempt that
    reaches at most ``n_givens`` filled cells is returned. If none does, the
    attempt that came closest (fewest givens) is returned instead.

    Args:
        n_givens: Target number of filled cells in the puzzle.
        max_iter: Maximum number of pluck attempts; values below 1 are
            treated as 1 so a puzzle is always produced.
        size: Side length of the grid (a perfect square).

    Returns:
        ``(quiz, solution)`` where blank cells in ``quiz`` are 0.
    """
    solution = construct_puzzle_solution(size)
    best_quiz: List[List[int]] = []
    best_givens = -1
    for _ in range(max(1, max_iter)):
        quiz, givens = pluck([row[:] for row in solution], n=n_givens, size=size)
        if best_givens < 0 or givens < best_givens:
            best_quiz, best_givens = quiz, givens
        if givens <= n_givens:
            break
    return best_quiz, solution


def grid_to_lines(grid: List[List[int]]) -> List[str]:
    """Format each row as its cell values joined by single spaces (blanks as ``0``)."""
    lines = []
    for row in grid:
        tokens = [str(value or 0) for value in row]
        lines.append(" ".join(tokens))
    return lines


def grid_to_string(grid: List[List[int]]) -> str:
    """Flatten a grid row-major into exactly one character per cell.

    Blank cells (0 or None) become ``"0"`` and values 1-9 are written as
    digits, so 4x4 and 9x9 grids give the usual digit string. Values 10-35
    (as in 16x16 grids) are written as letters ``A``-``Z`` so every cell stays
    a single character and the string can be split back into cells
    unambiguously (plain ``str(value)`` would make "1","0" and "10" collide).

    Raises:
        ValueError: If a cell holds a value outside 0-35.
    """
    symbols = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    chars = []
    for row in grid:
        for value in row:
            v = value or 0
            if not 0 <= v < len(symbols):
                raise ValueError(
                    f"cell value {value!r} cannot be encoded as a single character"
                )
            chars.append(symbols[v])
    return "".join(chars)


def _load_digit_font(font_path: str, px: int):
    """Return a font for cell digits at pixel size *px*.

    Tries *font_path* (only if it names an existing file), then
    "DejaVuSansMono.ttf", and finally falls back to Pillow's default font.
    """
    candidates = []
    if font_path and os.path.isfile(font_path):
        candidates.append(font_path)
    candidates.append("DejaVuSansMono.ttf")
    for candidate in candidates:
        try:
            return ImageFont.truetype(candidate, size=px)
        except (OSError, ValueError, ImportError):
            # Missing/unreadable font file, unusable size, or Pillow built
            # without FreeType: try the next candidate.
            continue
    return ImageFont.load_default()


def render_grid(
    grid: List[List[int]],
    size: int = 512,
    line_width: int = 4,
    font_path: str = "",
) -> Image.Image:
    """Render a square Sudoku grid as an RGB image.

    Draws a ``size`` x ``size`` white canvas with a 6% margin, grid lines of
    ``line_width`` and double-width lines on box boundaries (every
    ``int(sqrt(len(grid)))`` cells), then writes each non-zero value centred
    in its cell.

    Args:
        grid: Square list of rows; falsy cells (0/None) are left blank.
        size: Side length of the output image in pixels.
        line_width: Width of ordinary grid lines; box lines are twice as wide.
        font_path: Optional TrueType font file used for the digits.

    Returns:
        The rendered image.
    """
    grid_size = len(grid)
    img = Image.new("RGB", (size, size), (255, 255, 255))
    if grid_size == 0:
        return img  # nothing to draw; also avoids dividing by zero below
    draw = ImageDraw.Draw(img)

    margin = int(size * 0.06)
    board = size - margin * 2
    cell = board / grid_size
    box = int(grid_size ** 0.5)

    # Vertical and horizontal line i share the same offset from the margin.
    for i in range(grid_size + 1):
        lw = line_width * (2 if i % box == 0 else 1)
        offset = margin + int(round(i * cell))
        draw.line([(offset, margin), (offset, margin + board)], fill=(0, 0, 0), width=lw)
        draw.line([(margin, offset), (margin + board, offset)], fill=(0, 0, 0), width=lw)

    font = _load_digit_font(font_path, max(1, int(cell * 0.7)))

    for r, row in enumerate(grid):
        for c in range(grid_size):
            value = row[c]
            if not value:
                continue
            text = str(value)
            cx = margin + c * cell + cell / 2
            cy = margin + r * cell + cell / 2
            # textbbox at (0, 0) includes the glyph's own left/top offset, so
            # centre on the bbox midpoint rather than on half its width/height.
            left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
            x = int(round(cx - (left + right) / 2))
            y = int(round(cy - (top + bottom) / 2))
            draw.text((x, y), text, font=font, fill=(10, 10, 10))

    return img


def ensure_dirs(base_out: str) -> Dict[str, str]:
    """Create the ``question``, ``answer`` and ``sudoku`` folders under *base_out*.

    Existing folders are reused. Returns a mapping from each folder name to
    its path, in that order.
    """
    subdirs = {name: os.path.join(base_out, name) for name in ("question", "answer", "sudoku")}
    for path in subdirs.values():
        os.makedirs(path, exist_ok=True)
    return subdirs


def save_one(index: int, quiz: List[List[int]], solution: List[List[int]],
             dirs: Dict[str, str], size: int, line_width: int, font_path: str, givens: int, grid_size: int) -> None:
    """Write one puzzle as question/answer PNG images plus a JSON record.

    Files are named after ``index`` zero-padded to three digits (e.g.
    ``001.png`` / ``001.json``) and placed in ``dirs["question"]``,
    ``dirs["answer"]`` and ``dirs["sudoku"]``. The JSON stores each grid as a
    nested list, as space-separated lines, and as a flat string.

    Args:
        index: Sequence number used for the file stem.
        quiz: Puzzle grid with blanks as 0.
        solution: Completed grid.
        dirs: Output folders keyed by "question", "answer" and "sudoku".
        size, line_width, font_path: Passed through to ``render_grid``.
        givens: Requested number of givens, recorded as ``"givens"``. The quiz
            can contain more clues if plucking stopped early, so the real
            count is also recorded as ``"actual_givens"``.
        grid_size: Side length of the grid, recorded as ``"grid_size"``.
    """
    stem = f"{index:03d}"
    q_img = render_grid(quiz, size=size, line_width=line_width, font_path=font_path)
    a_img = render_grid(solution, size=size, line_width=line_width, font_path=font_path)
    q_img.save(os.path.join(dirs["question"], f"{stem}.png"))
    a_img.save(os.path.join(dirs["answer"], f"{stem}.png"))

    data = {
        "index": stem,
        "grid_size": grid_size,
        "givens": givens,
        "actual_givens": sum(1 for row in quiz for value in row if value),
        "question": {
            "grid": quiz,
            "lines": grid_to_lines(quiz),
            "flat": grid_to_string(quiz),
        },
        "answer": {
            "grid": solution,
            "lines": grid_to_lines(solution),
            "flat": grid_to_string(solution),
        },
    }
    with open(os.path.join(dirs["sudoku"], f"{stem}.json"), "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)


def _parse_int_list(parser: argparse.ArgumentParser, text: str, option: str) -> List[int]:
    """Parse a comma-separated list of integers, exiting via ``parser.error`` on bad input."""
    try:
        return [int(part.strip()) for part in text.split(",")]
    except ValueError:
        parser.error(f"{option} must be a comma-separated list of integers, got {text!r}")


def main():
    """Command-line entry point: parse arguments, generate puzzles, write outputs.

    For each (grid-size, givens, count) triple, ``count`` puzzles are generated
    and saved under ``--outdir`` with one running file index across all
    triples. Invalid arguments are reported through ``parser.error`` (usage
    message and exit status 2).
    """
    parser = argparse.ArgumentParser(description="Export Sudoku puzzles/solutions as images and JSON")
    parser.add_argument("--count", type=str, default="20, 80", help="Generation count list, comma-separated, corresponds one-to-one with grid-size")
    parser.add_argument("--outdir", type=str, default="outputs", help="Output root directory")
    parser.add_argument("--digits-font", dest="digits_font", type=str, default="", help="Font path for digits")
    parser.add_argument("--size", type=int, default=1024, help="Image size (square side length)")
    parser.add_argument("--line-width", dest="line_width", type=int, default=2, help="Grid line width")
    parser.add_argument("--grid-size", type=str, default="4, 9", help="Sudoku grid size list, comma-separated, e.g., '4,9,16'")
    parser.add_argument("--givens", type=str, default="10, 40", help="Number of given digits list, comma-separated, e.g., '8,20' (smaller = harder)")
    parser.add_argument("--seed", type=int, default=None, help="Random seed")
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)

    count_list = _parse_int_list(parser, args.count, "--count")
    givens_list = _parse_int_list(parser, args.givens, "--givens")
    grid_size_list = _parse_int_list(parser, args.grid_size, "--grid-size")

    if not (len(grid_size_list) == len(givens_list) == len(count_list)):
        parser.error(f"List lengths do not match: grid-size ({len(grid_size_list)}), givens ({len(givens_list)}), count ({len(count_list)}). They should correspond one-to-one.")

    # Validate every triple before creating any output, so a bad entry does
    # not leave a half-written run behind (non-square sizes would never finish).
    for grid_size, givens, count in zip(grid_size_list, givens_list, count_list):
        box = int(round(grid_size ** 0.5)) if grid_size > 0 else 0
        if grid_size < 1 or box * box != grid_size:
            parser.error(f"--grid-size values must be positive perfect squares (4, 9, 16, ...), got {grid_size}")
        if not 0 <= givens <= grid_size * grid_size:
            parser.error(f"--givens for grid-size {grid_size} must be between 0 and {grid_size * grid_size}, got {givens}")
        if count < 0:
            parser.error(f"--count values must be non-negative, got {count}")

    dirs = ensure_dirs(args.outdir)

    file_index = 1
    for grid_size, givens, count in zip(grid_size_list, givens_list, count_list):
        print(f"Generating Sudoku puzzles with grid-size={grid_size}x{grid_size}, givens={givens}, total {count}...")
        for _ in range(count):
            quiz, solution = generate_one(n_givens=givens, max_iter=20, size=grid_size)
            save_one(
                index=file_index,
                quiz=quiz,
                solution=solution,
                dirs=dirs,
                size=args.size,
                line_width=args.line_width,
                font_path=args.digits_font,
                givens=givens,
                grid_size=grid_size,
            )
            file_index += 1
        print(f"Completed grid-size={grid_size}x{grid_size}, givens={givens}, generated {count} puzzles")


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()


