from typing import Dict, Set, List, Literal, Sequence
import copy
import random
import numpy as np
from numpy.typing import NDArray


# Candidate sets keyed by flat cell index ``row * 9 + col`` (0..80); a set with
# exactly one element is treated as a solved cell by the helpers below.
type Grid = Dict[int, Set[int]]
# Cell-value array; boards built here are 9x9 (see empty_board) and 0 marks an
# empty/unknown cell.
type Board = NDArray[np.int8]


def empty_board() -> Board:
    """Create a fresh 9x9 int8 board with every cell set to 0 (empty)."""
    return np.full((9, 9), 0, dtype=np.int8)


def get_block(board: Board, i: int, j: int) -> Board:
    """Return the 3x3 block at block-row ``i`` and block-column ``j``.

    Blocks are numbered 0..2 along each axis, with ``(0, 0)`` at the top-left.
    The result is a NumPy slice, i.e. a view: writing to it writes to ``board``.

    Raises:
        IndexError: if ``i`` or ``j`` is outside ``range(3)``.
    """
    if not (0 <= i < 3 and 0 <= j < 3):
        raise IndexError(
            f"block index ({i}, {j}) out of range; expected 0 <= i, j < 3"
        )
    start_row, start_col = i * 3, j * 3
    return board[start_row:start_row + 3, start_col:start_col + 3]


def get_block_from_cell(board: Board, row: int, col: int):
    """Return the 3x3 block (a view into ``board``) containing cell ``(row, col)``."""
    block_i = row // 3
    block_j = col // 3
    return get_block(board, block_i, block_j)


def get_square_indices(row: int, col: int) -> List[int]:
    """
    Return the flat, zero-based indices (``r * 9 + c``) of the nine cells in
    the 3x3 square that contains cell ``(row, col)``, including that cell.

    Indices are listed row by row in ascending order.
    """
    top = (row // 3) * 3
    left = (col // 3) * 3
    return [r * 9 + c for r in range(top, top + 3) for c in range(left, left + 3)]


def is_valid_fill_number(board: Board, row: int, col: int, num: int):
    """Return True if ``num`` does not already appear in the row, column or
    3x3 block of cell ``(row, col)``.

    The cell's own current value takes part in the check, so to re-validate a
    digit that is already placed, the caller must clear that cell first.
    """
    if num in board[row, :]:
        return False
    if num in board[:, col]:
        return False
    return num not in get_block_from_cell(board, row, col)


def conflict_flags(board: Board):
    """Flag every placed digit that clashes with an equal digit in its unit.

    A cell holding a value in 1..9 is flagged when the same value also occurs
    elsewhere in its row, its column or its 3x3 block. Any other value (e.g. 0
    for an empty cell) is ignored and never flagged. Both members of a
    clashing pair are flagged.

    ``board`` is only read, never written, so read-only arrays and shared
    views are safe to pass.

    Returns:
        Boolean array with the same shape as ``board``; True marks a conflict.
    """
    flags = np.zeros_like(board, dtype=np.bool_)
    for row in range(9):
        for col in range(9):
            num = board[row, col]
            if not 1 <= num <= 9:
                continue  # empty / undetermined cell: nothing to check
            r0, c0 = (row // 3) * 3, (col // 3) * 3
            block = board[r0:r0 + 3, c0:c0 + 3]
            # The cell itself always contributes one match, so a count above 1
            # means another cell in the same unit holds the same digit.
            flags[row, col] = (
                np.count_nonzero(board[row, :] == num) > 1
                or np.count_nonzero(board[:, col] == num) > 1
                or np.count_nonzero(block == num) > 1
            )
    return flags


def _build_adjacents() -> Dict[int, List[List[int]]]:
    """Precompute, for each flat cell index, its peers in row, column and block.

    Cell indices are row-major (``index = row * 9 + col``). Each value is
    ``[row_peers, col_peers, block_peers]``. Every list excludes the cell
    itself and is in ascending order.
    """
    table: Dict[int, List[List[int]]] = {}
    for row in range(9):
        for col in range(9):
            index = row * 9 + col
            cells_in_row = [x for x in range(row * 9, row * 9 + 9) if x != index]
            cells_in_col = [x for x in range(col, 81, 9) if x != index]
            cells_in_sq = [x for x in get_square_indices(row, col) if x != index]
            table[index] = [cells_in_row, cells_in_col, cells_in_sq]
    return table


# Built inside a function so the loop variables (row, col, index, ...) do not
# leak into the module namespace, where `from <module> import *` would export them.
adjacents: Dict[int, List[List[int]]] = _build_adjacents()


class InvalidSudokuError(Exception):
    """
    Raised when a sudoku is found to be invalid.

    Intended cases: 1) the string representing the sudoku is not 81
    characters long; 2) the initially assigned values (the clues) conflict
    with each other (e.g. two 3s in the same row). ``eliminate`` also raises
    it when removing peer values leaves a cell with no candidates.

    Note: ``str_to_board`` reports a wrong input length with ``ValueError``,
    not with this exception.
    """


def _char2num(ch: str, invalid: int | None = None) -> int:
    try:
        out = int(ch)
        if 0 <= out <= 9:
            return out
        else:
            raise ValueError(f"Can not convert '{ch}' to an integer in [0, 9].")
    except ValueError:
        if invalid is None:
            raise
        else:
            return invalid


def str_to_board(seq: Sequence[str], invalid: int | None = None) -> Board:
    """Convert an 81-cell sequence (typically a string) into a 9x9 Sudoku board.

    Cells are read row by row and each one is parsed by ``_char2num``. A token
    that reads as an integer in [0, 9] becomes that value. Any other token is
    replaced by ``invalid`` when that argument is given.

    Args:
        seq: exactly 81 cell tokens, e.g. ``"530070000600195000..."``.
        invalid: value stored for unparseable or out-of-range cells (pass
            ``0`` to treat characters such as ``'.'`` as blanks). If None,
            such cells raise instead.

    Raises:
        ValueError: if ``seq`` does not have exactly 81 elements, or if a cell
            cannot be converted and ``invalid`` is None.
    """
    if len(seq) != 81:
        raise ValueError("The input must have exactly 81 cells.")

    cells = [_char2num(ch, invalid) for ch in seq]
    return np.array(cells, dtype=np.int8).reshape(9, 9)


def repr_board(board: Board, rowsep: str = '\n', colsep: str = ' ') -> str:
    """Render the 9x9 cells as text: cells joined by ``colsep``, rows by ``rowsep``."""
    return rowsep.join(
        colsep.join(str(board[r, c]) for c in range(9))
        for r in range(9)
    )


def board_to_grid(board: Board, fill_possibilities: bool = False) -> Grid:
    """Convert a 9x9 board into candidate sets keyed by flat index ``row * 9 + col``.

    A non-zero cell becomes the singleton ``{value}``. A zero cell becomes
    ``{1, ..., 9}`` when ``fill_possibilities`` is True, or an empty set
    otherwise.

    Every cell gets its own set object. Previously all empty cells shared one
    set, so mutating one cell's candidates silently changed every other empty
    cell.
    """
    assert board.shape == (9, 9)
    grid: Grid = {}
    for row in range(9):
        for col in range(9):
            value = int(board[row, col])
            if value:
                grid[row * 9 + col] = {value}
            elif fill_possibilities:
                grid[row * 9 + col] = set(range(1, 10))
            else:
                grid[row * 9 + col] = set()
    return grid


def grid_to_board(grid: Grid):
    """Collapse a grid of candidate sets back into a 9x9 int8 board.

    A cell whose candidate set has exactly one element takes that value. Every
    other cell (no candidates, or several) is left at 0.
    """
    assert len(grid) == 81
    board = empty_board()
    for index, candidates in grid.items():
        r, c = divmod(index, 9)
        if len(candidates) == 1:
            # Unpack the single candidate straight into the board cell.
            (board[r, c],) = candidates
    return board



def _validnums(x: NDArray[np.int8]) -> str:
    s = set(str(num) for num in x.flat if num > 0)
    if s:
        return ''.join(sorted(s))
    else:
        return '0'


def board_subsets(board: Board):
    """Summarise which digits are already placed in each row, column and block.

    Returns a dict with keys ``"rows"``, ``"cols"`` and ``"blocks"``. Each
    maps to a list of 9 strings from ``_validnums``. Blocks are listed
    row-major: top-left, top-middle, ..., bottom-right.
    """
    rows = [_validnums(board[r, :]) for r in range(9)]
    cols = [_validnums(board[:, c]) for c in range(9)]
    blocks = [_validnums(get_block(board, b // 3, b % 3)) for b in range(9)]
    return {"rows": rows, "cols": cols, "blocks": blocks}


def eliminate(grid: Grid, cell: int) -> tuple[set[int], bool]:
    """Narrow ``grid[cell]`` by removing values already fixed in its peers.

    Peers are the other cells in the same row, column and 3x3 block, taken
    from the module-level ``adjacents`` table. A peer counts as fixed when its
    candidate set has exactly one element. ``grid`` itself is not modified.

    Returns:
        ``(candidates, changed)``: the narrowed candidate set, and whether it
        actually differs from ``grid[cell]``. A cell that is already fixed is
        returned as-is with ``changed=False``.

    Raises:
        InvalidSudokuError: if removing the peers' values leaves no candidate.
    """
    possibilities = grid[cell]
    if len(possibilities) == 1:
        return possibilities, False

    peers = adjacents[cell][0] + adjacents[cell][1] + adjacents[cell][2]
    solved_vals = [grid[p] for p in peers if len(grid[p]) == 1]
    if not solved_vals:
        return possibilities, False

    new_poss = possibilities - set.union(*solved_vals)
    if not new_poss:
        raise InvalidSudokuError(f"cell {cell} has no remaining candidates")
    # Report a change only when something was really removed. Returning True
    # whenever a peer is solved means a caller that repeats until no change is
    # reported would never terminate.
    return new_poss, new_poss != possibilities


# def unique(grid: Grid):
#     unique_exist = False
#     action = None

#     for cell in range(81):
#         if len(grid[cell]) > 1:
#             possibilities = grid[cell]
#             for uniqueness in [0, 1, 2]:
#                 adjacent_values = set.union(*[grid[v] for v in adjacents[cell][uniqueness]])
#                 unique_values = possibilities - adjacent_values

#                 if len(unique_values) == 1:  # this action succeeded
#                     unique_exist = True
#                     grid[cell] = unique_values
#                     action = {
#                         'type': 'unique',
#                         'posistion': cell,
#                         'value': unique_values
#                     }
#                 # Did our reduction lead to an invalid state?
#                 # if {1, 2, 3, 4, 5, 6, 7, 8, 9} != adjacent_values | possibilities:
#                 #     return False
#         else:
#             continue
#
#     return unique_exist, grid, action


def reduce_grid(grid: Grid) -> tuple[bool, Grid]:
    """Run one elimination pass over every cell that has several candidates.

    Each such cell is narrowed with ``eliminate`` against the *input* grid.
    All cells are updated from the same snapshot, so a value fixed during this
    pass only propagates on the next pass. ``grid`` is not modified.

    Returns:
        ``(reduced, new_grid)``. ``reduced`` is True iff at least one cell's
        candidate set actually changed. ``new_grid`` is a deep copy of
        ``grid`` with the narrowed sets applied.

    Raises:
        InvalidSudokuError: propagated from ``eliminate`` when a cell is left
            with no candidates.
    """
    new_grid = copy.deepcopy(grid)
    reduced = False

    for cell in range(81):
        if len(grid[cell]) > 1:
            new_poss, _ = eliminate(grid, cell)
            new_grid[cell] = new_poss
            # Compare the sets directly rather than trusting a "changed" flag,
            # so `reduced` really means "the grid changed". This lets a
            # fixed-point loop over reduce_grid terminate.
            if new_poss != grid[cell]:
                reduced = True

    return reduced, new_grid


def get_reduce_action(grid: Grid) -> dict:
    """Describe a 'reduce' step: the fixed values present in every unit.

    A cell is fixed when its candidate set has exactly one element. The result
    has ``'type': 'reduce'``. Under ``'rows'``, ``'cols'`` and ``'subgrids'``
    it maps each unit index 0..8 to the set of fixed values in that unit.
    Subgrids are numbered row-major from the top-left.
    """
    rows: dict[int, set[int]] = {k: set() for k in range(9)}
    cols: dict[int, set[int]] = {k: set() for k in range(9)}
    subgrids: dict[int, set[int]] = {k: set() for k in range(9)}

    for cell in range(81):
        values = grid[cell]
        if len(values) != 1:
            continue
        r, c = divmod(cell, 9)
        rows[r].update(values)
        cols[c].update(values)
        subgrids[(r // 3) * 3 + c // 3].update(values)

    return {'type': 'reduce', 'rows': rows, 'cols': cols, 'subgrids': subgrids}
