import os
import datetime
import pickle
from tqdm import tqdm
import random
import numpy as np
from collections import deque

from graph.make_graph import (
    build_graph_prompt_fixed_length,
    build_graph_prompt_verify,
    get_graph_property,
)
from graph.prompts import reorder_incident_header
from util.llm import call_lm
from util.string import flip_numbers_in_string


def run(args):
    """Run the sample-then-self-verify graph shortest-path experiment.

    For each of ``args.num_run`` trials this builds few-shot prompts, samples
    candidate paths from the language model in the direction(s) selected by
    ``args.mode``, asks the model to choose among the candidates with a
    multiple-choice verification prompt, and records whether the chosen
    candidate is contained in ``fwd_solns`` as returned by
    ``build_graph_prompt_fixed_length``. All trials are pickled to a
    timestamped file under ``result/graph/`` and the mean success rate is
    printed.

    Args:
        args: Namespace-like object providing ``num_run``, ``directed``,
            ``num_node``, ``edge_rate``, ``len_shortest_path``, ``num_shot``,
            ``model``, ``api_key``, ``target_candidates``,
            ``prev_try_maxlen``, ``max_try_ratio``, ``temperature`` and
            ``mode`` (one of ``'fwd'``, ``'back'``, ``'flip'``,
            ``'fwd-back'``, ``'fwd-flip'``).

    Raises:
        ValueError: If ``args.mode`` is not one of the supported modes.
    """
    # configs
    num_run = args.num_run
    directed = args.directed
    num_nodes = args.num_node
    edge_rate = args.edge_rate
    len_shortest_path = args.len_shortest_path
    num_shot = args.num_shot
    model = args.model
    api_key = args.api_key
    target_candidates = args.target_candidates
    prev_try_maxlen = args.prev_try_maxlen
    max_try_ratio = args.max_try_ratio
    sample_T = args.temperature  # shared by every sampling direction

    # Determine mode; the mapped value doubles as the save-file prefix.
    settings = {
        'fwd': "fwd_verify",
        'back': "back_verify",
        'flip': "flip_verify",
        'fwd-back': "fwd_back_verify",
        'fwd-flip': "fwd_flip_verify",
    }
    if args.mode not in settings:
        # Raising a plain string is itself a TypeError in Python 3.
        raise ValueError(f"Unknown mode from args: {args.mode!r}")
    forward_only = args.mode == 'fwd'
    backward_only = args.mode == 'back'
    flip_only = args.mode == 'flip'
    no_backward = args.mode == 'fwd-flip'
    reorder_incident = directed and (flip_only or no_backward)

    # Prompt header and ender
    header = f"You will be given {'a directed' if directed else 'an undirected'} graph search problem with a few examples."
    fwd_ender = f"\n\nPlan the shortest path from initial to goal node for the this **{'directed' if directed else 'undirected'}** graph. Follow the format 'Shortest Path: (...)' and do not output anything else."
    back_ender = f"\n\nPlan the reversed shortest path from goal to initial node for the this **{'directed' if directed else 'undirected'}** graph. Follow the format 'Reversed Shortest Path: (...)' and do not output anything else."
    verify_ender = f"Remember the graph is {'directed' if directed else 'undirected'}. Follow the exact same format as the examples and check each options step by step. Begin with 'Checking each options step by step:'"
    reorder_ender = "\n\nRemember the edges are directed. Please re-order this directed graph with the exact same full procedure as the example. Follow the same format and do not output anything else."

    # Set up save path
    setting = settings[args.mode]
    os.makedirs("result/graph", exist_ok=True)
    save_path = f"result/graph/{setting}_{directed}_{num_nodes}_{int(edge_rate * 100)}_{len_shortest_path}_{num_shot}_{num_run}_{datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}.pkl"

    # Run all
    success_all = []
    candidates_all = []
    verify_raw_all = []
    verify_sol_all = []
    graph_all = []
    for _ in tqdm(range(num_run)):
        fwd_prompt = header
        back_prompt = header
        flip_prompt = header

        # Generate sampling prompt with few-shot examples
        for example_ind in range(num_shot):
            (
                fwd_example_prompt,
                back_example_prompt,
                flip_example_prompt,
                fwd_example_soln,
                back_example_soln,
                *_,  # graph object and solution sets are not needed for examples
            ) = build_graph_prompt_fixed_length(
                len_shortest_path=len_shortest_path,
                num_nodes=num_nodes,
                edge_rate=edge_rate,
                directed=directed,
            )
            example_title = f"\n\n** Example {example_ind+1} **\n"
            fwd_prompt += example_title + fwd_example_prompt + " " + fwd_example_soln
            back_prompt += example_title + back_example_prompt + " " + back_example_soln
            if reorder_incident:
                # When the current graph is re-ordered by the model, the
                # few-shot examples are plain forward examples.
                flip_prompt += (
                    example_title + fwd_example_prompt + " " + fwd_example_soln
                )
            else:
                flip_prompt += (
                    example_title + flip_example_prompt + " " + back_example_soln
                )

        # The problem actually being solved in this run.
        (
            fwd_graph_prompt,
            back_graph_prompt,
            flip_graph_prompt,
            _,
            _,
            G,
            node_init,
            node_goal,
            fwd_solns,
            back_solns,
        ) = build_graph_prompt_fixed_length(
            len_shortest_path=len_shortest_path,
            num_nodes=num_nodes,
            edge_rate=edge_rate,
            directed=directed,
            add_path_prompt=False,
        )
        fwd_prompt_with_problem = (
            fwd_prompt + "\n\n** Current problem **\n" + fwd_graph_prompt + fwd_ender
        ).strip()
        back_prompt_with_problem = (
            back_prompt + "\n\n** Current problem **\n" + back_graph_prompt + back_ender
        ).strip()
        flip_prompt_with_problem = (
            flip_prompt + "\n\n** Current problem **\n" + flip_graph_prompt + fwd_ender
        ).strip()

        # get properties
        num_computations_fwd, num_computations_back = get_graph_property(
            G,
            node_init,
            node_goal,
            computations=True,
            use_bfs=True,  # or dijkstra
            directed=directed,
        )

        # reorder incident for flip - assume fewshot examples do not need to be re-ordered, essentially using forward examples
        if reorder_incident:
            print("reordering...")
            cur_graph_str = flip_graph_prompt.split("Initial")[0].strip()
            reorder_query = reorder_incident_header + cur_graph_str + reorder_ender
            reorder_raw_response = call_lm(
                reorder_query,
                api_key,
                model=model,
                max_tokens=1024,
                temperature=0,
            )[0]
            reorder_parts = reorder_raw_response.split("Node", 1)
            if len(reorder_parts) > 1:
                reorder_graph_str = "Node" + reorder_parts[1]
            else:
                # Fall back to the raw response when no "Node" marker is found.
                print(
                    f"Cannot parse reorder response: {reorder_raw_response}"
                )
                reorder_graph_str = reorder_raw_response
            flip_prompt_with_problem = flip_prompt_with_problem.replace(
                cur_graph_str, reorder_graph_str
            )

        # Per-direction sampling recipe:
        # (log name, prompt, answer marker, whether the answer must be flipped back)
        samplers = {
            "fwd": ("forward", fwd_prompt_with_problem, "Shortest Path: ", False),
            "back": (
                "backward",
                back_prompt_with_problem,
                "Reversed Shortest Path: ",
                True,
            ),
            "flip": ("flipped", flip_prompt_with_problem, "Shortest Path: ", True),
        }
        # always start with deterministic; a direction switches to the
        # configured temperature after its first successfully parsed answer
        trial_T = {"fwd": 0, "back": 0, "flip": 0}

        # Sample multiple times
        max_try = max_try_ratio * target_candidates
        candidates_dir = []
        seen_candidates = set()
        prev_try = deque(maxlen=prev_try_maxlen)
        for _ in range(max_try):
            try:
                # Determine which direction/prompt to use
                use_fwd = random.random() < 0.5
                if forward_only:
                    direction = "fwd"
                elif backward_only:
                    direction = "back"
                elif flip_only:
                    direction = "flip"
                elif use_fwd:
                    direction = "fwd"
                elif no_backward:
                    direction = "flip"
                else:
                    direction = "back"
                log_name, prompt, marker, needs_flip = samplers[direction]
                print(f"Generating {log_name}...")
                response = call_lm(
                    prompt,
                    api_key,
                    model=model,
                    max_tokens=32,
                    temperature=trial_T[direction],
                )[0]
                candidate = response.rsplit(marker, 1)[1].strip()
                trial_T[direction] = sample_T  # update
                if needs_flip:
                    # Answer was produced goal -> initial; convert to forward order.
                    candidate = flip_numbers_in_string(candidate)
                prev_try.append(candidate)
                if candidate not in seen_candidates:
                    seen_candidates.add(candidate)
                    candidates_dir.append((candidate, direction))
                if len(candidates_dir) == target_candidates:
                    break
                # break if all in prev_try are the same
                if len(prev_try) == prev_try_maxlen and len(set(prev_try)) == 1:
                    break
            except Exception as exc:
                # Best effort: a failed call or unparsable answer costs one try.
                print(f"Error in parsing the output: {exc!r}")
                continue
        random.shuffle(candidates_dir)
        candidates = [c[0] for c in candidates_dir]

        if not candidates:
            # Nothing to verify; record a failure instead of crashing on an
            # empty option list.
            print("No candidate could be sampled; recording this run as a failure.")
            raw_verify = ""
            verify_sol = None
            success = False
        else:
            # Self-verification
            # One capital-letter label per candidate (A, B, C, ...).
            labels = [chr(ord("A") + i) for i in range(len(candidates))]
            verify_examples_prompt = header
            for example_ind in range(args.num_shot):
                example = build_graph_prompt_verify(
                    len_shortest_path=len_shortest_path,
                    num_nodes=num_nodes,
                    edge_rate=edge_rate,
                    directed=directed,
                )
                verify_examples_prompt += (
                    f"\n\n** Example {example_ind+1} **\n" + example
                )
            verify_prompt = (
                verify_examples_prompt
                + "\n\n** Current problem **\n"
                + fwd_graph_prompt.rsplit("Shortest Path:", 1)[0]
                + "\nWhich one is the correct shortest path?\n"
            )
            for label, candidate in zip(labels, candidates):
                verify_prompt += label + ". " + candidate + "\n"
            verify_prompt += verify_ender
            raw_verify = call_lm(
                verify_prompt,
                api_key,
                model=model,
                max_tokens=1024,
                temperature=0,
                stop=["Print", "Since"],  # tend to ramble from Since...
            )[0].strip()
            verify_parts = raw_verify.split("Thus")
            if len(verify_parts) > 1:
                verify_label = verify_parts[1].strip()
            else:
                verify_label = ""
                print(f"Error: no 'Thus' found in verify response: {raw_verify}")
            # assume there is no other capital letter
            possible_labels = [label for label in labels if label in verify_label]
            if not possible_labels:
                print(
                    f"Unknown verify answer (choose random one): {verify_label}"
                )
                verify_label = random.choice(labels)
            else:
                verify_label = random.choice(possible_labels)

            # Check answer
            verify_sol = candidates[labels.index(verify_label)]
            success = verify_sol in fwd_solns

        success_all.append(success)
        print(
            f"==== FINAL ====\ncandidates {candidates}, chosen {verify_sol}, sols {fwd_solns}, success: {success}"
        )

        # Save
        # candidates_dir was shuffled in place above, so this matches the
        # option order shown to the verifier.
        candidates_all.append(candidates_dir)
        verify_sol_all.append(verify_sol)
        verify_raw_all.append(raw_verify)
        graph_all.append(
            {
                'G': G,
                'node_init': node_init,
                'node_final': node_goal,
                'fwd_solns': fwd_solns,
                'back_soln': back_solns,
                'fwd_prompt': fwd_prompt,
                'back_prompt': back_prompt,
                'flip_prompt': flip_prompt,
                'num_computations_fwd': num_computations_fwd,
                'num_computations_back': num_computations_back,
            }
        )

    # save and print results
    with open(save_path, "wb") as f:
        pickle.dump(
            {
                "successes": success_all,
                "candidates": candidates_all,
                "verify_sol_all": verify_sol_all,
                "verify_raw_all": verify_raw_all,
                "graphs": graph_all,
            },
            f,
        )
    print(f"saved files to: {save_path}")
    print('Success rate:', np.mean(success_all))

if __name__ == "__main__":
    # Command-line entry point: parse options, validate the API key, echo
    # the configuration, then launch the experiment.
    import argparse
    parser = argparse.ArgumentParser()
    # Baseline choice; restricted to the modes that run() understands
    parser.add_argument(
        "--mode",
        type=str,
        default='fwd',
        choices=['fwd', 'back', 'flip', 'fwd-back', 'fwd-flip'],
    )
    # Graph structures
    # Node count is parsed as an integer so a user-supplied value has the
    # same type as the default.
    parser.add_argument("--num_node", type=int, default=12)
    parser.add_argument("--edge_rate", type=float, default=0.2)
    parser.add_argument("--directed", action='store_true')
    parser.add_argument("--len_shortest_path", type=int, default=5)
    # Your api key
    parser.add_argument("--api_key", type=str, default="API_KEY")
    # Fixed
    parser.add_argument("--num_run", type=int, default=1)
    parser.add_argument("--model", type=str, default="gpt-4o-2024-05-13")
    parser.add_argument("--temperature", type=float, default=0.5)
    parser.add_argument("--num_shot", type=int, default=3)
    parser.add_argument("--target_candidates", type=int, default=4)
    parser.add_argument("--prev_try_maxlen", type=int, default=4)
    parser.add_argument("--max_try_ratio", type=int, default=2)
    args = parser.parse_args()

    # Warning
    if args.api_key == "API_KEY":
        # Must raise an exception instance; raising a bare string is a TypeError.
        raise ValueError("Please set the API key!")

    # Print all arguments
    # NOTE(review): this also echoes the API key to stdout.
    print("Arguments:")
    for arg in vars(args):
        print(f"{arg}: {getattr(args, arg)}")

    # Run
    print("Running...")
    run(args)