import lmql
import lmql.algorithms as la
import asyncio

# All queries below run against the "openai/text-curie-001" model.

def make_board(s):
    """Parse a grid string into a list of rows.

    The string is split on "\\n"; every character of a row becomes one cell.
    A space maps to ``None`` (an empty cell) and any other character is
    converted with ``int``, so a non-digit character raises ``ValueError``.
    """
    return [
        [int(ch) if ch != " " else None for ch in row]
        for row in s.split("\n")
    ]

def render_board(s):
    """Render a grid string as text with one board row per line.

    The input is parsed with ``make_board``. Each filled cell is written as
    its digit followed by a space, each empty cell as two spaces, and every
    row is terminated by a newline.
    """
    rendered_rows = []
    for row in make_board(s):
        cells = ["  " if cell is None else str(cell) + " " for cell in row]
        rendered_rows.append("".join(cells) + "\n")
    return "".join(rendered_rows)

# LMQL query: fill in a board cell by cell using the `argmax` decoder
# (openai_chunksize=1) on "openai/text-curie-001".
#
# `board` is a grid string in the format accepted by make_board (rows
# separated by "\n", a space for every empty cell). The query walks the grid
# in row-major order: a pre-filled cell is echoed into the prompt as " {v}",
# an empty cell becomes the hole " [N]" that the model must fill. The `where`
# clause restricts N to a single token that is one of "1".."9".
#
# Returns a (status, prompt) tuple. Status is "invalid" as soon as int(N)
# fails, or when the completed grid does not contain exactly nine distinct
# values; otherwise it is "success".
#
# NOTE: the docstring below is the LMQL program itself, not documentation;
# its text is part of the runtime behaviour.
@lmql.query
async def sudoku(board):
    '''lmql   
    argmax(openai_chunksize=1)  
        """
        Solve the following 3x3 sudoku:
        """

        m = make_board(board)

        for i in range(len(m)):
            for j in range(len(m[i])):
                if m[i][j] is None:
                    " [N]"
                    try:
                        m[i][j] = int(N)
                    except:
                        return ("invalid", context.prompt)
                else:
                    v = m[i][j]
                    " {v}"
            "\n"
        s = set()
        for r in m:
            for v in r:
                s.add(v)
        if len(s) == 9:
            return ("success", context.prompt)
        else:
            return ("invalid", context.prompt)
    from
        "openai/text-curie-001" 
    where
        len(TOKENS(N)) == 1 and N in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]
    '''
# LMQL query: show the whole board first, then let the model write the
# solution as one free-form completion (`argmax` decoder, no `where` clause).
#
# The board is rendered into the prompt with render_board and the model's
# answer is captured in RESULT. Only the first three non-blank lines of
# RESULT are inspected: every non-whitespace character is parsed with int()
# and collected into a set.
#
# Returns a (status, text) tuple. If a character cannot be parsed the result
# is ("invalid", context.prompt). Otherwise the text is context.prompt with
# the sorted list of collected digits appended, and the status is "success"
# when exactly nine distinct digits were found, else "invalid".
#
# NOTE: the docstring below is the LMQL program itself, not documentation;
# its text is part of the runtime behaviour.
@lmql.query
async def sudoku_reorder(board):
    '''lmql   
    argmax   
        """
        Solve the following 3x3 sudoku:
        
        {render_board(board)}
        
        Now fill in the remaining spots
        [RESULT]
        """
        s = set()
        lines = [l for l in RESULT.split("\n") if l.strip() != ""]
        for l in lines[:3]:
            for n in l:
                if len(n.strip()) > 0:
                    try:
                        s.add(int(n))
                    except:
                        return ("invalid", context.prompt)
        if len(s) == 9:
            return ("success", context.prompt + str(sorted(list(s))))
        else:
            return ("invalid", context.prompt + str(sorted(list(s))))
    from
        "openai/text-curie-001" 
    '''


# LMQL query: same cell-by-cell program as `sudoku`, but decoded with
# `bsseq(openai_chunksize=1)` instead of `argmax`.
#
# `board` is a grid string in the format accepted by make_board. Pre-filled
# cells are echoed into the prompt as " {v}"; every empty cell becomes the
# hole " [N]", constrained by the `where` clause to a single token that is
# one of "1".."9".
#
# Returns ("success", prompt) when the completed grid contains exactly nine
# distinct values, and ("invalid", prompt) otherwise or when int(N) fails.
#
# NOTE: the docstring below is the LMQL program itself, not documentation;
# its text is part of the runtime behaviour.
@lmql.query
async def sudoku_bsseq(board):
    '''lmql   
    bsseq(openai_chunksize=1)
        """
        Solve the following 3x3 sudoku:
        """

        m = make_board(board)

        for i in range(len(m)):
            for j in range(len(m[i])):
                if m[i][j] is None:
                    " [N]"
                    try:
                        m[i][j] = int(N)
                    except:
                        return ("invalid", context.prompt)
                else:
                    v = m[i][j]
                    " {v}"
            "\n"
        s = set()
        for r in m:
            for v in r:
                s.add(v)
        if len(s) == 9:
            return ("success", context.prompt)
        else:
            return ("invalid", context.prompt)
    from
        "openai/text-curie-001" 
    where
        len(TOKENS(N)) == 1 and N in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]
    '''

# LMQL query: cell-by-cell board completion decoded with
# `beam_var(openai_chunksize=1, n=5)`.
#
# `board` is a grid string in the format accepted by make_board. Unlike
# `sudoku`, the hole is written as "[N]" with no leading space; the separator
# is instead part of the allowed values (" 1".." 9"), each of which must be a
# single token. int() ignores the leading space when the value is parsed.
# Pre-filled cells are still echoed as " {v}".
#
# Each return statement yields ("success", prompt) when the completed grid
# contains exactly nine distinct values, and ("invalid", prompt) otherwise or
# when int(N) fails.
#
# NOTE: the docstring below is the LMQL program itself, not documentation;
# its text is part of the runtime behaviour.
@lmql.query
async def sudoku_beam(board):
    '''lmql   
    beam_var(openai_chunksize=1, n=5) 
        """
        Solve the following 3x3 sudoku:
        """

        m = make_board(board)

        for i in range(len(m)):
            for j in range(len(m[i])):
                if m[i][j] is None:
                    "[N]"
                    try:
                        m[i][j] = int(N)
                    except:
                        return ("invalid", context.prompt)
                else:
                    v = m[i][j]
                    " {v}"
            "\n"
        s = set()
        for r in m:
            for v in r:
                s.add(v)
        if len(s) == 9:
            return ("success", context.prompt)
        else:
            return ("invalid", context.prompt)
    from
        "openai/text-curie-001" 
    where
        len(TOKENS(N)) == 1 and N in [" 1", " 2", " 3", " 4", " 5", " 6", " 7", " 8", " 9"]
    '''


# LMQL query: cell-by-cell board completion decoded with
# `var(n=3, b=3, openai_chunksize=1, subdecoder="beam")`.
#
# The query body is the same as in `sudoku_beam`: the hole is written as
# "[N]" with no leading space, and the `where` clause limits N to a single
# token out of " 1".." 9" (the separator is part of the value; int() ignores
# it when parsing). Pre-filled cells are echoed as " {v}".
#
# Each return statement yields ("success", prompt) when the completed grid
# contains exactly nine distinct values, and ("invalid", prompt) otherwise or
# when int(N) fails.
#
# NOTE: the docstring below is the LMQL program itself, not documentation;
# its text is part of the runtime behaviour.
@lmql.query
async def sudoku_var(board):
    '''lmql   
    var(n=3, b=3, openai_chunksize=1, subdecoder="beam")
        """
        Solve the following 3x3 sudoku:
        """

        m = make_board(board)

        for i in range(len(m)):
            for j in range(len(m[i])):
                if m[i][j] is None:
                    "[N]"
                    try:
                        m[i][j] = int(N)
                    except:
                        return ("invalid", context.prompt)
                else:
                    v = m[i][j]
                    " {v}"
            "\n"
        s = set()
        for r in m:
            for v in r:
                s.add(v)
        if len(s) == 9:
            return ("success", context.prompt)
        else:
            return ("invalid", context.prompt)
    from
        "openai/text-curie-001" 
    where
        len(TOKENS(N)) == 1 and N in [" 1", " 2", " 3", " 4", " 5", " 6", " 7", " 8", " 9"]
    '''


async def main():
    """Run the sudoku queries over a fixed set of boards and print tallies.

    Caching is switched on through ``la.caching(True)``; each query is then
    mapped over all boards with ``la.map``, one query after the other, and a
    line of the form ``<label> <successes> / <total>`` is printed per query.
    """
    boards = [
        "9 6\n  8\n31 ",
        "9  \n218\n453",
        "38 \n   \n52 ",
        "142\n7 5\n  3",
        " 1 \n  2\n 8 ",
        "18 \n439\n 7 ",
        " 46\n123\n7  ",
        " 92\n17 \n 5 ",
        "571\n8 6\n243",
        "183\n95 \n27 ",
    ]

    la.caching(True)

    def flat_status(outcome):
        # The outcome is the (status, text) tuple itself.
        return outcome[0]

    def nested_status(outcome):
        # The status sits one level deeper for these runs -- presumably the
        # decoder hands back a list of (status, text) results per board;
        # TODO confirm against the LMQL version in use.
        return outcome[0][0]

    def report(label, outcomes, status_of):
        # Count the outcomes flagged "success" and print the tally.
        solved = sum(1 for outcome in outcomes if status_of(outcome) == "success")
        print(label, solved, "/", len(outcomes))

    report("argmax-reorder: ", await la.map(sudoku_reorder, boards), flat_status)

    # NOTE: the sudoku_bsseq run ("bsseq: ") is currently disabled.

    report("argmax: ", await la.map(sudoku, boards), flat_status)
    report("beam_var: ", await la.map(sudoku_beam, boards), nested_status)
    report("var: ", await la.map(sudoku_var, boards), nested_status)

# Start the benchmark only when this file is executed as a script, so that
# importing the module (for example to reuse one of the queries) does not
# immediately run every query against the model.
if __name__ == "__main__":
    asyncio.run(main())