import random
import collections
import json
from tqdm import tqdm
from .template import PROMPT_TEMPLATE_ZH, PROMPT_TEMPLATE

def diff_j(board: list, K: int) -> int:
    """
    Count the inversions of a board read in row-major order.

    Tiles equal to 0 are dropped before counting. An inversion is any pair of
    positions i < j in the remaining sequence with tiles[i] > tiles[j].

    Args:
        board (list): The board state, given as a list of rows
        K (int): Size of the board (K x K); not used by the computation

    Returns:
        int: Number of inversions
    """
    # Row-major flattening, skipping any 0 tile.
    tiles = []
    for row in board:
        tiles.extend(value for value in row if value != 0)

    # Compare every tile against each tile that appears after it.
    return sum(
        1
        for pos, left in enumerate(tiles)
        for right in tiles[pos + 1:]
        if left > right
    )

def generate_puzzle(K: int, diff: list) -> list:
    """
    Generate a random K x K board whose inversion count lies in a given range.

    The board holds each tile 1..K**2 exactly once. Random permutations are
    drawn until one's inversion count (as computed by ``diff_j``) falls
    inside ``diff``, inclusive on both ends.

    Args:
        K (int): Size of the board (K x K)
        diff (list): Inclusive difficulty range [min_inversions, max_inversions]

    Returns:
        list: ``[board, inversions]``, where ``board`` is a list of K rows of
        K tiles each and ``inversions`` is its inversion count.

    Raises:
        ValueError: If no permutation of K**2 tiles can have an inversion
            count inside ``diff``; without this check the search never ends.
    """
    low, high = diff[0], diff[1]
    n_tiles = K * K
    # A permutation of n items has between 0 and n*(n-1)/2 inversions, and
    # every count in that span is attainable.
    max_inversions = n_tiles * (n_tiles - 1) // 2
    if low > high or high < 0 or low > max_inversions:
        raise ValueError(
            f"difficulty range {diff!r} is unreachable for a {K}x{K} board "
            f"(inversions span 0..{max_inversions})"
        )

    tiles = list(range(1, n_tiles + 1))
    # Rejection sampling: reshuffle until the inversion count is in range.
    while True:
        random.shuffle(tiles)
        # Slice the shuffled list into K rows; slices are copies, so later
        # shuffles do not mutate a returned board.
        board = [tiles[i:i + K] for i in range(0, n_tiles, K)]
        inv = diff_j(board, K)
        if low <= inv <= high:
            return [board, inv]

def board_to_string(board: list) -> str:
    """
    Convert board state to string representation.

    Each row becomes one line of space-separated tiles, and rows are joined
    with newlines. Every tile is right-justified to a shared column width.
    The width is at least 2 characters and widens to the longest tile label,
    so boards containing 3+ digit tiles (K >= 10) stay aligned.

    Args:
        board (list): The board state, given as a list of rows

    Returns:
        str: String representation of the board ("" for an empty board)
    """
    labels = [[str(tile) for tile in row] for row in board]
    widest = max((len(label) for row in labels for label in row), default=0)
    width = max(widest, 2)
    return '\n'.join(' '.join(label.rjust(width) for label in row) for row in labels)

def generate(count: int = 100, difficulty: str = 'medium', language: str = 'en', split: str = "train"):
    """
    Yield ``count`` sixteen-puzzle problems of a single difficulty tier.

    Each problem is a 4x4 board whose inversion count falls in the window
    for ``difficulty``: easy 0-45, medium 46-59, hard 60-100 (inclusive).
    The English prompt template is used when ``language == 'en'``; any other
    value selects the Chinese template.

    Args:
        count (int): Number of problems to generate
        difficulty (str): Difficulty level ('easy', 'medium', or 'hard')
        language (str): Language of the problems ('en' or 'zh')
        split (str): Dataset split ('train' or 'test')

    Yields:
        dict: Generated problem with prompt, answer, and metadata

    Raises:
        KeyError: On the first iteration, if ``difficulty`` is not one of
            the three known tiers.
    """
    if language == 'en':
        template = PROMPT_TEMPLATE
    else:
        template = PROMPT_TEMPLATE_ZH
    inversion_window = {
        "easy": [0, 45],
        "medium": [46, 59],
        "hard": [60, 100],
    }[difficulty]
    side = 4

    for idx in tqdm(range(count)):
        puzzle, inversions = generate_puzzle(side, inversion_window)
        # Build the metadata first; it is attached as the last key below so
        # the serialized key order matches the record layout.
        metadata = {
            "id": f"16-puzzle_{difficulty}{idx}",
            "question": puzzle,
            "answer": puzzle,
            "inversion": inversions,
            "rationale": "",
            "split": split,
            "type": "sequential_puzzle",
            "source_url": "auto-generated",
            "dataset_name": "sixteen_puzzle",
            "difficulty_level": difficulty,
            "language": language,
        }
        yield {
            "prompt": template.format(question=board_to_string(puzzle)),
            "answer": puzzle,
            "task_name": "sixteen_puzzle",
            "ability": "logic_puzzle",
            "language": language,
            "meta": metadata,
        }

def save_to_jsonl(of1: str, of2: str, count: int, lange: str = 'en', split: str = "train"):
    """
    Save generated problems to JSONL files.

    ``count`` is divided as evenly as possible across the 'easy', 'medium'
    and 'hard' tiers, which are written in that order. Any remainder goes to
    the earliest tiers, so exactly ``count`` problems are written (for
    ``count >= 0``). For example, 100 is split as 34/33/33.

    Args:
        of1 (str): Output file path for the main data (full records)
        of2 (str): Output file path for the metadata (each record's "meta")
        count (int): Total number of problems to generate
        lange (str): Language for the problems ('en' or 'zh')
        split (str): Dataset split recorded in each problem's metadata
    """
    tiers = ('easy', 'medium', 'hard')
    per_tier, remainder = divmod(count, len(tiers))
    with open(of1, 'w', encoding='utf-8') as f1, open(of2, 'w', encoding='utf-8') as f2:
        for tier_index, level in enumerate(tiers):
            # The first `remainder` tiers each take one extra problem.
            n_level = per_tier + (1 if tier_index < remainder else 0)
            for item in generate(n_level, level, lange, split):
                f1.write(json.dumps(item, ensure_ascii=False) + '\n')
                f2.write(json.dumps(item["meta"], ensure_ascii=False) + '\n')

