import random
from .template import PROMPT_TEMPLATE
from itertools import permutations, product
import json
from tqdm import tqdm

def generate_numbers(num_nums: int) -> list:
    """
    Draw ``num_nums`` random integers, each uniform over the closed range 1..13.

    Args:
        num_nums (int): How many integers to draw

    Returns:
        list: The drawn integers, in draw order
    """
    lowest, highest = 1, 13
    return [random.randint(lowest, highest) for _ in range(num_nums)]

# Binary operators tried between each pair of consecutive numbers.
operations = list('+-*/')
# Absolute tolerance used when deciding that a (possibly float) result equals 24.
EPSILON = 1e-6

def save_eval(exp: str) -> float:
    """
    Evaluate an arithmetic expression string, mapping division by zero to None.

    This uses the builtin ``eval``. Within this module it only receives
    expressions assembled from numbers, operators and parentheses; never pass
    it untrusted text.

    Args:
        exp (str): The expression to evaluate, e.g. ``'(3+5)*3'``

    Returns:
        float: The value of ``exp`` (an ``int`` when no true division is
        involved), or None if evaluating it raises ZeroDivisionError
    """
    outcome = None
    try:
        outcome = eval(exp)
    except ZeroDivisionError:
        # Leave outcome as None so callers can skip this candidate.
        pass
    return outcome

def eval_4(n: list, o: list) -> str:
    """
    Evaluate expressions for 4 numbers to get 24.

    The numbers keep their order. The five possible parenthesizations are
    tried in a fixed order, and the first whose value lies within EPSILON of
    24 is returned. A shape that divides by zero is skipped, and the
    remaining shapes are still tried.

    Args:
        n (list): List of 4 numbers as strings (callers in this module may pass
            negative or fractional values, e.g. intermediate results from
            eval_5)
        o (list): List of 3 operators; o[i] sits between n[i] and n[i+1]

    Returns:
        str: Valid expression that equals 24, or None if no solution exists
    """
    a, b, c, d = n[0], n[1], n[2], n[3]
    op1, op2, op3 = o[0], o[1], o[2]
    shapes = (
        # (n1 op1 n2) op2 (n3 op3 n4)
        '(' + a + op1 + b + ')' + op2 + '(' + c + op3 + d + ')',
        # ((n1 op1 n2) op2 n3) op3 n4
        '((' + a + op1 + b + ')' + op2 + c + ')' + op3 + d,
        # n1 op1 ((n2 op2 n3) op3 n4)
        a + op1 + '((' + b + op2 + c + ')' + op3 + d + ')',
        # (n1 op1 (n2 op2 n3)) op3 n4
        '(' + a + op1 + '(' + b + op2 + c + '))' + op3 + d,
        # n1 op1 (n2 op2 (n3 op3 n4))
        a + op1 + '(' + b + op2 + '(' + c + op3 + d + '))',
    )
    for exp in shapes:
        value = save_eval(exp)
        # save_eval returns None on division by zero. Skip that shape instead
        # of computing abs(None - 24), whose TypeError would abandon every
        # remaining shape.
        if value is not None and abs(value - 24) < EPSILON:
            return exp
    return None

def eval_5(n: list, o: list) -> str:
    """
    Search for a way to combine 5 numbers (in the given order) into 24.

    For each adjacent pair, left to right, ``n[i] o[i] n[i+1]`` is collapsed
    into its value. The resulting 4 numbers are then passed to ``eval_4``
    together with the remaining 3 operators, kept in order. Pairs whose
    evaluation divides by zero are skipped.

    Args:
        n (list): List of 5 numbers as strings
        o (list): List of 4 operators; o[i] sits between n[i] and n[i+1]

    Returns:
        str: ``"<pair>=<value>, <4-number expression>"`` for the first pair
        that leads to 24, or None if no solution exists
    """
    nums = list(n[:5])
    ops = list(o[:4])
    for i in range(4):
        pair = nums[i] + ops[i] + nums[i + 1]
        value = save_eval(pair)
        if value is None:
            continue
        value_str = str(value)
        # Replace the pair with its value and drop the operator that joined it.
        found = eval_4(nums[:i] + [value_str] + nums[i + 2:], ops[:i] + ops[i + 1:])
        if found:
            return pair + '=' + value_str + ", " + found
    return None

def eval_6(n: list, o: list) -> str:
    """
    Search for a way to combine 6 numbers (in the given order) into 24.

    For each adjacent pair, left to right, ``n[i] o[i] n[i+1]`` is collapsed
    into its value. The resulting 5 numbers are then passed to ``eval_5``
    together with the remaining 4 operators, kept in order. Pairs whose
    evaluation divides by zero are skipped.

    Args:
        n (list): List of 6 numbers as strings
        o (list): List of 5 operators; o[i] sits between n[i] and n[i+1]

    Returns:
        str: ``"<pair>=<value>, <eval_5 result>"`` for the first pair that
        leads to 24, or None if no solution exists
    """
    nums = list(n[:6])
    ops = list(o[:5])
    for i in range(5):
        pair = nums[i] + ops[i] + nums[i + 1]
        value = save_eval(pair)
        if value is None:
            continue
        value_str = str(value)
        # Replace the pair with its value and drop the operator that joined it.
        found = eval_5(nums[:i] + [value_str] + nums[i + 2:], ops[:i] + ops[i + 1:])
        if found:
            return pair + '=' + value_str + ", " + found
    return None

def can_form_24(nums: list, lang: str = 'en') -> str:
    """
    Check if the given numbers can form 24 using basic arithmetic operations.

    Every distinct ordering of ``nums`` is combined with every choice of
    operators from ``operations``. Each combination is handed to the solver
    for the input size (``eval_4``, ``eval_5`` or ``eval_6``), and the first
    solution found is returned.

    Args:
        nums (list): 4, 5 or 6 numbers to check
        lang (str): Language for the response (currently unused; both
            messages below are English)

    Returns:
        str: ``'The answer is: <expression> = 24'`` for the first solution
        found, or ``'cannot form 24'`` if no solution exists

    Raises:
        ValueError: if ``nums`` does not contain 4, 5 or 6 numbers (no solver
            exists for other sizes)
    """
    answer_cue = 'The answer is: '
    refuse_cue = 'cannot form 24'
    solvers = {4: eval_4, 5: eval_5, 6: eval_6}
    len_nums = len(nums)
    solver = solvers.get(len_nums)
    if solver is None:
        raise ValueError(
            f"can_form_24 supports 4, 5 or 6 numbers, got {len_nums}"
        )

    # permutations() treats equal values at different positions as distinct,
    # so repeated numbers produce duplicate orderings. dict.fromkeys drops
    # them while keeping first-seen order, so the first solution found is the
    # same as when searching every ordering.
    orderings = dict.fromkeys(tuple(map(str, p)) for p in permutations(nums))
    for ordering in orderings:
        n = list(ordering)
        for o in product(operations, repeat=len_nums - 1):
            try:
                res_o = solver(n, o)
            except (ArithmeticError, TypeError):
                # A division by zero in an intermediate step can surface as
                # ZeroDivisionError, or as a TypeError from arithmetic on the
                # None that save_eval returns. Either way, this combination
                # has no solution.
                continue
            if res_o:
                return f"{answer_cue}{res_o} = 24"
    return refuse_cue

def generate(count: int = 100, difficulty: str = 'medium', language: str = 'en', split: str = "train"):
    """
    Lazily produce ``count`` game24 puzzles of a single difficulty.

    The difficulty sets how many numbers each puzzle has: 'easy' -> 4,
    'medium' -> 5, 'hard' -> 6. Any other value raises KeyError when the
    generator is first advanced.

    Args:
        count (int): Number of puzzles to produce
        difficulty (str): 'easy', 'medium' or 'hard'
        language (str): Language tag stored on each record and passed to the
            solver
        split (str): Dataset split label stored in the metadata

    Yields:
        dict: A record with the rendered prompt, the solver's answer, task
        labels, and a ``meta`` field holding a JSON-encoded string
    """
    size_for_level = {'easy': 4, 'medium': 5, 'hard': 6}
    size = size_for_level[difficulty]
    template = PROMPT_TEMPLATE

    for idx in tqdm(range(count)):
        numbers = generate_numbers(size)
        question_text = ",".join(str(x) for x in numbers)
        answer = can_form_24(numbers, language)
        metadata = {
            "id": f"game24_{difficulty}_{idx}",
            "question": numbers,
            "answer": answer,
            "rationale": "",
            "split": split,
            "type": "code_puzzle",
            "source_url": "auto-generated",
            "dataset_name": "game24",
            "difficulty_level": difficulty,
            "language": language,
        }
        yield {
            "prompt": template.format(question=question_text),
            "answer": answer,
            "task_name": "game24",
            "ability": "logic_puzzle",
            "language": language,
            "meta": json.dumps(metadata),
        }

def save_to_jsonl(of1: str, of2: str, count: int, lange: str = 'en'):
    """
    Save generated puzzles to JSONL files.

    ``count`` is split as evenly as possible across the 'easy', 'medium' and
    'hard' difficulties. Any remainder of ``count / 3`` goes to the earlier
    levels, so the per-level counts add up to ``count``.

    Args:
        of1 (str): Output file path for the full records (one JSON object per
            line; ``meta`` stays a JSON-encoded string inside each record)
        of2 (str): Output file path for the metadata (one JSON object per
            line)
        count (int): Total number of puzzles to generate
        lange (str): Language for the puzzles
    """
    difficulties = ('easy', 'medium', 'hard')
    base, extra = divmod(count, len(difficulties))
    with open(of1, 'w', encoding='utf-8') as f1, open(of2, 'w', encoding='utf-8') as f2:
        for rank, difficulty in enumerate(difficulties):
            # Hand out the remainder one puzzle at a time, starting with 'easy'.
            per_level = base + (1 if rank < extra else 0)
            for item in generate(per_level, difficulty, lange):
                f1.write(json.dumps(item, ensure_ascii=False) + '\n')
                # item["meta"] is already a JSON string. Decode it so each
                # metadata line is an object rather than a quoted string.
                meta = json.loads(item["meta"])
                f2.write(json.dumps(meta, ensure_ascii=False) + '\n')

