import random
import collections
import json
from tqdm import tqdm
import copy
from .template import PROMPT_TEMPLATE


def generate_full_sudoku(base: int = 3) -> list:
    """
    Generate a complete valid Sudoku board.

    A fixed baseline solution is produced by a shifting pattern, then
    randomized with transformations that keep the board valid: relabeling
    the digits, permuting the bands (groups of ``base`` rows) and the rows
    inside each band, and likewise for column stacks and columns.

    Args:
        base (int): Side length of one box. The board has ``base * base``
            rows and columns. Defaults to 3 (classic 9x9 Sudoku).

    Returns:
        list: Square board (list of rows) filled with the integers
            1..base*base.

    Raises:
        ValueError: If ``base`` is smaller than 1.
    """
    if base < 1:
        raise ValueError(f"base must be a positive integer, got {base}")

    side = base * base

    def pattern(r: int, c: int) -> int:
        """Return the digit index of cell (r, c) in the baseline solution."""
        # Each row inside a band is shifted by `base`, and each band is
        # shifted by one more cell, so rows, columns and boxes never repeat.
        return (base * (r % base) + r // base + c) % side

    def shuffle(s) -> list:
        """Return the items of s as a new list in random order."""
        return random.sample(s, len(s))

    group = range(base)
    # Rows/columns may only be permuted within their own band/stack (and the
    # bands/stacks among themselves); anything else would break the boxes.
    rows = [g * base + r for g in shuffle(group) for r in shuffle(group)]
    cols = [g * base + c for g in shuffle(group) for c in shuffle(group)]
    nums = shuffle(range(1, side + 1))

    # Produce board using randomized baseline pattern
    return [[nums[pattern(r, c)] for c in cols] for r in rows]

def remove_numbers_from_board(board: list, holes: int = 40) -> list:
    """
    Blank out randomly chosen cells of a solved board to form a puzzle.

    The input board is left untouched; the puzzle is built on a deep copy.

    Args:
        board (list): Complete square Sudoku board (list of rows)
        holes (int): How many distinct cells to blank out

    Returns:
        list: Copy of the board in which the chosen cells are set to 0
    """
    size = len(board)
    puzzle = copy.deepcopy(board)

    # Sample distinct flat cell indices, then map each back to (row, column).
    for flat_index in random.sample(range(size * size), holes):
        row, col = divmod(flat_index, size)
        puzzle[row][col] = 0  # 0 marks an empty cell

    return puzzle

def sudoku_to_string(board: list) -> str:
    """
    Render a Sudoku board as text, one line per row.

    Cells within a row are separated by single spaces, empty cells (0) are
    shown as ".", and every row, including the last, ends with a newline.

    Args:
        board (list): Sudoku board (list of rows)

    Returns:
        str: Text rendering of the board
    """
    lines = []
    for row in board:
        cells = ["." if value == 0 else str(value) for value in row]
        lines.append(" ".join(cells) + "\n")
    return "".join(lines)

def generate(count: int = 100, difficulty: str = 'medium', language: str = 'en', split: str = "train"):
    """
    Lazily produce Sudoku puzzle records.

    For each record a full board is generated, a hole count is drawn
    uniformly from the range configured for the difficulty, and that many
    cells are blanked out. The full board is stored as the answer.
    NOTE: no check is made here that the puzzle has a unique solution.

    Args:
        count (int): Number of puzzles to generate
        difficulty (str): Difficulty level ('easy', 'medium', or 'hard')
        language (str): Language tag stored in each record
        split (str): Dataset split stored in the metadata ('train' or 'test')

    Yields:
        dict: Record with the formatted prompt, the solved board, fixed task
            labels, and a JSON-encoded metadata string
    """
    template = PROMPT_TEMPLATE

    # Inclusive [min, max] number of blanked cells per difficulty level.
    hole_ranges = {"easy": [10, 20], "medium": [20, 35], "hard": [35, 50]}
    min_holes, max_holes = hole_ranges[difficulty]

    for index in tqdm(range(count)):
        solution = generate_full_sudoku()
        hole_count = random.randint(min_holes, max_holes)
        puzzle = remove_numbers_from_board(solution, hole_count)
        prompt = template.format(question=sudoku_to_string(puzzle))

        metadata = {
            "id": f"sudoku_{difficulty}_{index}",
            "question": puzzle,
            "holes": hole_count,
            "answer": solution,
            "rationale": "",
            "split": split,
            "type": "sudoku_puzzle",
            "source_url": "auto-generated",
            "dataset_name": "sudoku",
            "difficulty_level": difficulty,
            "language": language,
        }

        yield {
            "prompt": prompt,
            "answer": solution,
            "task_name": "sudoku",
            "ability": "logic_puzzle",
            "language": language,
            "meta": json.dumps(metadata),
        }

def save_to_jsonl(output_file: str, count: int, lange: str = 'en'):
    """
    Save generated puzzles to a JSONL file.

    The puzzles are split as evenly as possible across the 'easy', 'medium'
    and 'hard' difficulties, written in that order. When ``count`` is not a
    multiple of three, the leftover puzzles go to the earlier difficulties,
    so exactly ``count`` records are written.

    Args:
        output_file (str): Output file path (overwritten if it exists)
        count (int): Total number of puzzles to generate
        lange (str): Language for the puzzles
    """
    difficulties = ('easy', 'medium', 'hard')
    per_level, leftover = divmod(count, len(difficulties))

    with open(output_file, 'w', encoding='utf-8') as f:
        for position, difficulty in enumerate(difficulties):
            # The first `leftover` difficulties each take one extra puzzle so
            # that the remainder of the division is not silently dropped.
            quota = per_level + (1 if position < leftover else 0)
            for item in generate(quota, difficulty, lange):
                f.write(json.dumps(item, ensure_ascii=False) + '\n')

