import random
from datasets import Dataset

def create_solved_sudoku(seed=None):
    """Return a randomized, fully solved 4x4 Sudoku grid.

    The grid starts from a fixed valid solution. Then:

    1. each horizontal band (rows 0-1, rows 2-3) may have its two rows swapped;
    2. each vertical stack (cols 0-1, cols 2-3) may have its two columns swapped;
    3. the digits are relabeled with a random permutation of 1..4.

    Every step preserves Sudoku validity. Band/stack swaps and transposition
    are not applied, so not every valid 4x4 grid is reachable.

    Args:
        seed: When not None, the module-global ``random`` generator is
            reseeded with it. This is a side effect that later ``random``
            calls elsewhere will observe.

    Returns:
        A list of 4 rows, each a list of 4 ints in 1..4.
    """
    if seed is not None:
        random.seed(seed)

    base = [
        [1, 2, 3, 4],
        [3, 4, 1, 2],
        [2, 1, 4, 3],
        [4, 3, 2, 1],
    ]

    # Row swaps inside each band: top band first, then bottom band.
    # The coin-flip order matters for reproducibility with a given seed.
    for top in (0, 2):
        if random.choice([True, False]):
            base[top], base[top + 1] = base[top + 1], base[top]

    # Column swaps inside each stack: left stack first, then right stack.
    for left in (0, 2):
        if random.choice([True, False]):
            for line in base:
                line[left], line[left + 1] = line[left + 1], line[left]

    # Relabel: digit d becomes labels[d - 1].
    labels = [1, 2, 3, 4]
    random.shuffle(labels)
    return [[labels[value - 1] for value in line] for line in base]


def is_valid(grid, row, col, num):
    """Return True if *num* can be written at ``grid[row][col]``.

    The placement is legal when *num* does not already appear in the same
    row, the same column, or the same 2x2 box of the 4x4 grid. Empty cells
    are expected to hold 0, so they never collide with a digit 1..4.
    """
    if num in grid[row]:
        return False

    column = [grid[r][col] for r in range(4)]
    if num in column:
        return False

    # Top-left corner of the 2x2 box containing (row, col).
    top, left = row - row % 2, col - col % 2
    box = [grid[r][c] for r in (top, top + 1) for c in (left, left + 1)]
    return num not in box


def count_solutions(grid):
    """Count completions of *grid* by backtracking, stopping once it exceeds 1.

    Empty cells are 0. Cells are tried in row-major order with candidates
    1..4, and each placement is checked with ``is_valid``. The grid is
    modified during the search, but every tentatively filled cell is reset
    to 0 before returning.

    Returns:
        0 if there is no completion, 1 if there is exactly one, and some
        value >= 2 if there are several. Because the search stops early, a
        value >= 2 is not the exact total. A grid with no empty cells returns
        1 without checking its existing digits for conflicts.
    """
    blank = next(
        ((r, c) for r in range(4) for c in range(4) if grid[r][c] == 0),
        None,
    )
    if blank is None:
        return 1  # No blanks left: the grid itself counts as one solution.

    r, c = blank
    total = 0
    for candidate in (1, 2, 3, 4):
        if not is_valid(grid, r, c, candidate):
            continue
        grid[r][c] = candidate
        total += count_solutions(grid)
        grid[r][c] = 0
        if total > 1:
            return total  # Early exit: uniqueness is already ruled out.
    return total


def grid_to_visual_string(grid):
    """Render *grid* as plain text.

    Each row becomes one line of space-separated cells, with empty cells
    (value 0) shown as ``.``. Lines are joined with newlines and have no
    trailing whitespace.
    """
    rendered_rows = []
    for line in grid:
        tokens = ["." if cell == 0 else str(cell) for cell in line]
        rendered_rows.append(" ".join(tokens).rstrip())
    return "\n".join(rendered_rows)


def grid_to_16digit_answer(grid):
    """Concatenate the grid's cells in row-major order and parse them as an int.

    For a solved 4x4 grid (digits 1-4, no zeros) the result has exactly
    16 digits.
    """
    return int("".join(str(cell) for line in grid for cell in line))


def create_puzzle(num_prefilled=8, max_attempts=1000, seed=None):
    """Create a 4x4 Sudoku puzzle that has exactly one solution.

    Each attempt builds a solved grid, blanks ``16 - num_prefilled``
    randomly chosen cells, and keeps the result only if ``count_solutions``
    reports a single completion.

    Args:
        num_prefilled: Number of given clues. Must be between 1 and 16.
        max_attempts: How many candidate puzzles to try before giving up.
        seed: When not None, the global ``random`` generator is reseeded.
            Attempt k builds its solved grid with seed ``seed + k``, so
            output is reproducible for a given seed.

    Returns:
        ``(puzzle_string, answer_integer)`` on success. ``(None, None)`` if
        *num_prefilled* is out of range or no unique puzzle was found within
        *max_attempts* (a message is printed in both cases).
    """
    if seed is not None:
        random.seed(seed)

    if num_prefilled < 1 or num_prefilled > 16:
        print("Number of prefilled cells must be between 1 and 16")
        return None, None

    blanks = 16 - num_prefilled
    for offset in range(max_attempts):
        attempt_seed = None if seed is None else seed + offset
        solution = create_solved_sudoku(attempt_seed)

        # The shuffle draws from the global RNG state that
        # create_solved_sudoku just advanced, which keeps the chosen blank
        # cells reproducible for a given seed.
        cells = [(r, c) for r in range(4) for c in range(4)]
        random.shuffle(cells)
        hidden = set(cells[:blanks])
        puzzle = [
            [0 if (r, c) in hidden else solution[r][c] for c in range(4)]
            for r in range(4)
        ]

        if count_solutions([line[:] for line in puzzle]) == 1:
            return grid_to_visual_string(puzzle), grid_to_16digit_answer(solution)

    print("Failed to generate unique puzzle after many attempts.")
    return None, None


def print_sudoku(grid, title="Sudoku Grid"):
    """Pretty-print a 4x4 Sudoku grid with lines between the 2x2 boxes.

    Empty cells (value 0) are shown as ``.``.

    Args:
        grid: 4x4 list of lists of ints, where 0 means "empty".
        title: Heading printed, followed by a colon, above the grid.
    """
    # Each box segment of a row renders as "| d d " (6 characters), so the
    # matching border segment is "+" followed by 5 dashes. The original used
    # 7 dashes, which made borders wider than the rows.
    border = "+" + "-" * 5 + "+" + "-" * 5 + "+"

    print(f"\n{title}:")
    print(border)
    for i, row in enumerate(grid):
        cells = ["." if num == 0 else str(num) for num in row]
        print("| " + " ".join(cells[:2]) + " | " + " ".join(cells[2:]) + " |")
        if i == 1:
            print(border)  # Separator between the top and bottom bands.
    print(border)

def generate_sudoku_dataset(num_examples=10000, num_prefilled=8, seed=42):
    """Build a ``Dataset`` of distinct 4x4 Sudoku puzzles.

    Each record has ``'question'``, the puzzle rendered as text, and
    ``'answer'``, the solution as a 16-digit int. Call k passes
    ``seed + k`` to ``create_puzzle``, so the output is reproducible when
    *seed* is not None. ``create_puzzle`` derives its per-attempt seeds as
    ``seed + attempt``, so neighbouring seeds can produce the same puzzle.
    Duplicates are filtered by the puzzle string.

    Generation stops after ``num_examples * 100`` calls to ``create_puzzle``.
    If fewer than *num_examples* puzzles were collected by then, a warning is
    printed and the shorter dataset is returned.
    """
    records = []
    seen = set()
    budget = num_examples * 100  # Cap on create_puzzle calls.
    tries = 0

    while len(records) < num_examples and tries < budget:
        puzzle_seed = seed + tries if seed is not None else None
        tries += 1
        question, answer = create_puzzle(num_prefilled, seed=puzzle_seed)

        if question is None or answer is None:
            continue
        if question in seen:
            continue
        seen.add(question)
        records.append({
            'question': question,
            'answer': answer
        })

    if len(records) < num_examples:
        print(f"Warning: Could only generate {len(records)} unique puzzles")

    return Dataset.from_list(records)
