import lmql
import pickle
import networkx
import asyncio
import lmql.algorithms as la
import json

# LMQL query: walk a dungeon graph with argmax decoding until the model reaches
# the room named 'Exit' or exceeds the step budget.
#
# Args:
#   graph_json: JSON string with "rooms" (room id -> room name) and "hallways"
#     (room id -> list of neighbouring room ids), e.g. as produced by
#     convert_graph(). Keys arrive as strings and are converted back with int().
#
# Returns:
#   ("success" | "failure", steps, context.num_calls, context.prompt), where
#   `steps` counts every model move, including invalid ones. If the first room
#   is already 'Exit', it returns "success" with 0 steps.
#
# Behaviour notes:
#   - The walk starts at the first key of "rooms"; the preceding `node = 0`
#     is immediately overwritten (dead assignment).
#   - An invalid room choice emits a System correction, leaves the player in
#     place, and still counts as a step.
#   - ACTION is constrained to " 0".." 9", so rooms with id >= 10 can never be
#     chosen. NOTE(review): confirm dungeons never have more than 10 rooms.
#   - NOTE(review): the step-limit check runs after the move, so an 11th move
#     is still generated but is always reported as failure, even if it lands
#     on the Exit.
#   - NOTE(review): OpenAI has deprecated text-davinci-003; confirm the model
#     is still reachable.
@lmql.query
async def explore(graph_json):
    '''lmql
    argmax
        graph = json.loads(graph_json)
        rooms, hallways = graph['rooms'], graph['hallways']
        rooms = {int(k): v for k, v in rooms.items()}
        hallways = {int(k): tuple(v) for k, v in hallways.items()}
        node = 0

        node = list(rooms.keys())[0]
        "System: You are exploring a dungeon. Your goal is to find the exit.\n"

        steps = 0
        max_steps = 10
        
        while rooms[node] != 'Exit':
            name = rooms[node]
            neighbours = hallways[node]
            "System: You are in room {node} '{name}'. "
            "You can go to {neighbours}. "
            "Where do you want to go?\n"
            "You:[ACTION]\n"
            next_node = int(ACTION.strip())
            if next_node not in neighbours:
                "System: {next_node} is not a valid neighboring room of '{name}'. Valid rooms are {neighbours}. \n"
            else:
                node = next_node
            steps += 1

            if steps > max_steps:
                "System: You have taken too many steps. You lose.\n"
                return ("failure", steps, context.num_calls, context.prompt)
        return ("success", steps, context.num_calls, context.prompt)
    from
        "openai/text-davinci-003"
    where
        ACTION in [" 0", " 1", " 2", " 3", " 4", " 5", " 6", " 7", " 8", " 9"]
    '''

# LMQL query: same dungeon walk as `explore`, but decoded with
# beam_var(return_first=True) instead of argmax.
#
# Args:
#   graph_json: JSON string with "rooms" (room id -> room name) and "hallways"
#     (room id -> list of neighbouring room ids), e.g. as produced by
#     convert_graph(). Keys arrive as strings and are converted back with int().
#
# Returns:
#   Each query program returns ("success" | "failure", steps,
#   context.num_calls, context.prompt). main() handles outputs whose first
#   element is not a string by taking their first element. NOTE(review): this
#   presumably means the decoder can wrap results in a candidate list; confirm
#   against the LMQL beam_var documentation.
#
# Behaviour notes (shared with `explore`):
#   - `node = 0` is dead; the walk starts at the first key of "rooms".
#   - Invalid choices are corrected, leave the player in place, and count as a step.
#   - ACTION is limited to " 0".." 9", so rooms with id >= 10 are unreachable.
#   - NOTE(review): an 11th move is generated but always reported as failure,
#     because the step-limit check runs after the move.
#   - NOTE(review): text-davinci-003 is deprecated by OpenAI; confirm availability.
@lmql.query
async def explore_beam(graph_json):
    '''lmql
    beam_var(return_first=True)
        graph = json.loads(graph_json)
        rooms, hallways = graph['rooms'], graph['hallways']
        rooms = {int(k): v for k, v in rooms.items()}
        hallways = {int(k): tuple(v) for k, v in hallways.items()}
        node = 0

        node = list(rooms.keys())[0]
        "System: You are exploring a dungeon. Your goal is to find the exit.\n"

        steps = 0
        max_steps = 10
        
        while rooms[node] != 'Exit':
            name = rooms[node]
            neighbours = hallways[node]
            "System: You are in room {node} '{name}'. "
            "You can go to {neighbours}. "
            "Where do you want to go?\n"
            "You:[ACTION]\n"
            next_node = int(ACTION.strip())
            if next_node not in neighbours:
                "System: {next_node} is not a valid neighboring room of '{name}'. Valid rooms are {neighbours}. \n"
            else:
                node = next_node
            steps += 1

            if steps > max_steps:
                "System: You have taken too many steps. You lose.\n"
                return ("failure", steps, context.num_calls, context.prompt)
        return ("success", steps, context.num_calls, context.prompt)
    from
        "openai/text-davinci-003"
    where
        ACTION in [" 0", " 1", " 2", " 3", " 4", " 5", " 6", " 7", " 8", " 9"]
    '''


# LMQL query: same dungeon walk as `explore`, but decoded with
# var(n=4, b=2, subdecoder="beam").
#
# Args:
#   graph_json: JSON string with "rooms" (room id -> room name) and "hallways"
#     (room id -> list of neighbouring room ids), e.g. as produced by
#     convert_graph(). Keys arrive as strings and are converted back with int().
#
# Returns:
#   Each query program returns ("success" | "failure", steps,
#   context.num_calls, context.prompt). main() handles outputs whose first
#   element is not a string by taking their first element.
#
# Behaviour notes:
#   - Unlike `explore` / `explore_beam`, ACTION here is constrained to "0".."9"
#     with NO leading space, so the prompt reads "You:3" rather than "You: 3".
#     NOTE(review): confirm this difference is intentional; it makes the
#     strategies' prompts non-identical.
#   - `node = 0` is dead; the walk starts at the first key of "rooms".
#   - Invalid choices are corrected, leave the player in place, and count as a step.
#   - Rooms with id >= 10 are unreachable because ACTION is a single digit.
#   - NOTE(review): an 11th move is generated but always reported as failure,
#     because the step-limit check runs after the move.
#   - NOTE(review): text-davinci-003 is deprecated by OpenAI; confirm availability.
@lmql.query
async def explore_var(graph_json):
    '''lmql
    var(n=4, b=2, subdecoder="beam")
        graph = json.loads(graph_json)
        rooms, hallways = graph['rooms'], graph['hallways']
        rooms = {int(k): v for k, v in rooms.items()}
        hallways = {int(k): tuple(v) for k, v in hallways.items()}
        node = 0

        node = list(rooms.keys())[0]
        "System: You are exploring a dungeon. Your goal is to find the exit.\n"

        steps = 0
        max_steps = 10
        
        while rooms[node] != 'Exit':
            name = rooms[node]
            neighbours = hallways[node]
            "System: You are in room {node} '{name}'. "
            "You can go to {neighbours}. "
            "Where do you want to go?\n"
            "You:[ACTION]\n"
            next_node = int(ACTION.strip())
            if next_node not in neighbours:
                "System: {next_node} is not a valid neighboring room of '{name}'. Valid rooms are {neighbours}. \n"
            else:
                node = next_node
            steps += 1

            if steps > max_steps:
                "System: You have taken too many steps. You lose.\n"
                return ("failure", steps, context.num_calls, context.prompt)
        return ("success", steps, context.num_calls, context.prompt)
    from
        "openai/text-davinci-003"
    where
        ACTION in ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]
    '''

def convert_graph(G):
    """Serialize a dungeon graph into the JSON string consumed by the explore queries.

    Args:
        G: graph object exposing ``G.nodes`` (iterable of node ids, indexable
            by id to a per-node attribute mapping containing ``'name'``) and
            ``G.neighbors(n)``. Presumably a networkx graph loaded from
            dungeons.pkl — confirm against the data producer.

    Returns:
        A JSON string ``{"rooms": {id: name}, "hallways": {id: [neighbour ids]}}``.
        ``json.dumps`` turns the node-id keys into strings; the explore queries
        convert them back with ``int(k)``, so node ids should be integer-like.
    """
    rooms = {n: G.nodes[n]["name"] for n in G.nodes}
    # Tuples serialize as JSON arrays; the queries rebuild them with tuple(v).
    hallways = {n: tuple(G.neighbors(n)) for n in G.nodes}
    return json.dumps({"rooms": rooms, "hallways": hallways})

def _unpack_results(results):
    """Return one result tuple per query from a list of decoder outputs.

    An entry whose first element is a string is already a result tuple and is
    kept as-is. Any other entry is treated as a sequence of result tuples, and
    its first element is kept. NOTE(review): this is presumably the candidate
    list returned by branching decoders; confirm against LMQL's decoder docs.
    """
    return [r if isinstance(r[0], str) else r[0] for r in results]


def _print_summary(label, results):
    """Print success rate, mean steps of successful runs, and mean model calls.

    *results* is a list of ``(status, steps, num_calls, prompt)`` tuples as
    returned by the explore queries. An average with no data is printed as
    "n/a" instead of raising ZeroDivisionError.
    """
    successes = [r for r in results if r[0] == "success"]
    success_rate = f"{len(successes)}/{len(results)}"
    average_steps = (
        sum(r[1] for r in successes) / len(successes) if successes else "n/a"
    )
    # The +1 per run is kept from the original accounting.
    # NOTE(review): confirm why one extra call per run is counted.
    average_calls = (
        sum(r[2] + 1 for r in results) / len(results) if results else "n/a"
    )
    print(label, success_rate, " average steps:", average_steps, " average calls:", average_calls)


async def main():
    """Run the argmax, beam_var and var explorers over dungeons.pkl and print stats."""
    la.caching(True)

    # NOTE(review): pickle.load can execute arbitrary code; only load a
    # trusted, locally generated dungeons.pkl.
    with open("dungeons.pkl", "rb") as f:
        graphs = pickle.load(f)
    for g in graphs:
        print(len(g.nodes))
    graphs = [convert_graph(g) for g in graphs]

    # argmax: results are used as returned (no unpacking).
    results = await la.map(explore, graphs, progress=True)
    _print_summary("argmax", results)

    # beam_var
    results = _unpack_results(await la.map(explore_beam, graphs, progress=True))
    for r in results:
        print(r[-1])  # the final context.prompt of each run
    _print_summary("beam_var", results)

    # var
    results = _unpack_results(await la.map(explore_var, graphs, progress=True))
    _print_summary("var", results)

# Only run the benchmark when executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())