import numpy as np
from data_process.sdk.solution import (
    solve_sudoku,
    is_valid_board,
    find_best_empty_cell,
    get_candidates
)


class Board:
    """Sudoku board class for puzzle initialization, solving and validation.

    The grid lives in ``self.board`` as a 9x9 integer ``np.ndarray``; a cell
    value of 0 marks an empty cell (non-zero cells are treated as givens by
    :meth:`solve`).
    """

    def __init__(self, board_init=None):
        """Create an empty 9x9 board, optionally loading ``board_init``.

        Args:
            board_init: Optional puzzle accepted by :meth:`set_board`
                (str, list or np.ndarray). ``None`` leaves an all-zero board.
        """
        self.board = np.zeros((9, 9), dtype=int)
        if board_init is not None:
            self.set_board(board_init)

    def set_board(self, board_input):
        """Set original puzzle with support for str, list or np.ndarray input.

        The input is always converted to a fresh integer array, so later
        changes to the caller's object do not leak into this board. The
        current board is only replaced once the new one has been validated.

        Args:
            board_input: A puzzle string (parsed by ``str_to_board``), a nested
                list, or an ``np.ndarray``.

        Raises:
            TypeError: If ``board_input`` is not a str, list or np.ndarray.
            ValueError: If the resulting board is not 9x9 (or a list/array
                cannot be converted to an integer array).
        """
        if isinstance(board_input, str):
            from data_process.sdk.solution import str_to_board
            candidate = np.array(str_to_board(board_input), dtype=int)
        elif isinstance(board_input, (list, np.ndarray)):
            # np.array copies by default and normalizes the dtype to int,
            # matching the np.zeros(..., dtype=int) board from __init__.
            candidate = np.array(board_input, dtype=int)
        else:
            raise TypeError("Input must be str, list or np.ndarray")
        # Explicit check instead of `assert`, which is stripped under -O.
        # Validate before assigning so a bad input leaves self.board intact.
        if candidate.shape != (9, 9):
            raise ValueError("Board must be 9x9")
        self.board = candidate

    def solve(self):
        """Solve the Sudoku puzzle (update self.board with valid solution if exists).

        Returns:
            True if a solution was found and stored in ``self.board``;
            False otherwise, with a diagnostic printed and ``self.board`` left
            unchanged.
        """
        if not self.input_is_valid():
            print("Invalid board, cannot solve")
            return False

        original_board = self.board.copy()
        # Give the solver its own copy: if it fills its argument in place,
        # `original_board` must still hold the givens for the check below.
        solution = solve_sudoku(original_board.copy())

        if not isinstance(solution, (list, np.ndarray)):
            print("No solution: Solver returned empty")
            return False

        try:
            solution_arr = np.array(solution, dtype=int)
        except (TypeError, ValueError):
            # Ragged or non-numeric output cannot be a 9x9 grid.
            print("No solution: Incomplete solution or incorrect shape")
            return False
        if solution_arr.shape != (9, 9) or np.count_nonzero(solution_arr == 0) > 0:
            print("No solution: Incomplete solution or incorrect shape")
            return False

        givens = original_board != 0
        if not np.array_equal(original_board[givens], solution_arr[givens]):
            print("No solution: Solution conflicts with given numbers")
            return False

        # A filled grid is only a solution if every value is a digit 1-9 and
        # it passes the same rule check used for the input puzzle.
        out_of_range = np.any((solution_arr < 1) | (solution_arr > 9))
        if out_of_range or not is_valid_board(solution_arr.copy()):
            print("No solution: Solution violates Sudoku rules")
            return False

        self.board = solution_arr
        return True

    def input_is_valid(self):
        """Check if the original puzzle is valid (no conflicts in given numbers).

        A copy is passed so ``is_valid_board`` cannot modify ``self.board``.
        """
        return is_valid_board(self.board.copy())

    def has_unique_solution(self):
        """Check if the puzzle has a unique solution (for strict validation).

        Returns:
            True only if the givens are conflict-free and a depth-first search
            finds exactly one completion. The search stops as soon as a second
            solution is found. ``self.board`` is never modified.
        """
        # Conflicting givens are never re-checked by the search below, so a
        # complete but invalid board would otherwise count as one solution.
        if not self.input_is_valid():
            return False

        work = self.board.copy()
        found = 0

        def backtrack_count(board_arr):
            """Count completions of board_arr; return False to stop early."""
            nonlocal found
            empty = find_best_empty_cell(board_arr)
            if not empty:
                found += 1
                return found <= 1
            row, col = empty
            for num in get_candidates(board_arr, (row, col)):
                board_arr[row][col] = num
                if not backtrack_count(board_arr):
                    # A second solution was found; `work` is a scratch copy,
                    # so no need to undo the assignment.
                    return False
                board_arr[row][col] = 0
            return True

        backtrack_count(work)
        return found == 1