import os
import datetime
import pickle
import numpy as np
import random
from tqdm import tqdm
from collections import deque

from util.llm import call_lm
from arrayt.make_graph import (
    generate_pcfg_graph,
    str_to_fns,
    pcfg_reverse,
    pcfg_shift_left,
    pcfg_shift_right,
    pcfg_swap,
    pcfg_repeat,
    pcfg_cut,
)


def _sample_graph(vocab, array_size, **graph_kwargs):
    """Draw random initial arrays until ``generate_pcfg_graph`` returns a graph.

    Each retry draws a fresh ``array_size``-length array from ``vocab`` (with
    replacement). Returns the full tuple from ``generate_pcfg_graph`` whose
    first element (the graph) is not None.
    """
    while True:
        vars_init = np.random.choice(vocab, size=array_size, replace=True).tolist()
        result = generate_pcfg_graph(vars_init=vars_init, **graph_kwargs)
        if result[0] is not None:
            return result


def _swap_shift_names(name):
    """Exchange every "shift_left" with "shift_right" and vice versa in *name*.

    Splitting on one token and renaming the other inside each piece performs
    the swap in a single pass, with no placeholder strings.
    """
    return "shift_right".join(
        piece.replace("shift_right", "shift_left") for piece in name.split("shift_left")
    )


def _parse_candidate(response, direction):
    """Extract the function list from an LLM response, in forward order.

    The candidate is read from the first line after the last "functions: "
    marker (case-insensitive), with surrounding brackets stripped and items
    split on ", ". Backward/flipped answers are reversed. Flipped answers also
    have their shift directions swapped.

    Returns a tuple of lower-cased function-name strings, or None when the
    response contains no "functions: " marker.
    """
    lowered = response.lower()
    if "functions: " not in lowered:
        return None
    answer = lowered.split("functions: ")[-1].split("\n")[0].strip("[]")
    str_fns = tuple(answer.split(", "))
    if direction in ("back", "flip"):
        str_fns = str_fns[::-1]
    if direction == "flip":
        str_fns = tuple(_swap_shift_names(s) for s in str_fns)
    return str_fns


def run(args):
    """Run the array-transformation planning experiment and pickle the results.

    For each of ``args.num_run`` trials, sample ``args.num_shot`` few-shot
    examples plus one problem via ``generate_pcfg_graph``. Then query the LLM
    for a function sequence up to ``args.num_try`` times. Each candidate not
    seen among the recent attempts is self-verified with a second LLM call.
    The attempt loop stops early when the verification contains "Correct",
    or when the last ``args.prev_try_maxlen`` candidates are all identical.
    The last raw response is then applied to the initial array and compared
    with the true final array.

    ``args.mode`` selects the generation direction(s): 'fwd', 'back', 'flip',
    'fwd-back' or 'fwd-flip'. The two-direction modes pick a direction per
    attempt with ``random.random() < 0.5``.

    Successes, candidates and graphs are pickled under ``result/arrayt/``,
    and the mean success rate is printed.

    Raises:
        ValueError: if ``args.mode`` or ``args.funcs`` is unknown, or
            ``args.num_try`` is smaller than 1.
    """
    # configs
    num_run = args.num_run
    funcs = args.funcs
    num_fns = args.num_fns
    array_size = args.array_size
    num_shot = args.num_shot
    model = args.model
    api_key = args.api_key
    prev_try_maxlen = args.prev_try_maxlen
    temperature = args.temperature
    num_try = args.num_try
    max_repeat = args.max_repeat  # only one repeat to make prompt not too long
    len_shortest_path = num_fns

    # mode -> (save-name prefix, direction when coin < 0.5, direction otherwise)
    mode_table = {
        "fwd": ("fwd_verify", "fwd", "fwd"),
        "back": ("back_verify", "back", "back"),
        "flip": ("flip_verify", "flip", "flip"),
        "fwd-back": ("fwd_back_verify", "fwd", "back"),
        "fwd-flip": ("fwd_flip_verify", "fwd", "flip"),
    }
    if args.mode not in mode_table:
        raise ValueError(f"Unknown mode from args: {args.mode!r}")
    setting, coin_fwd_direction, coin_other_direction = mode_table[args.mode]

    if num_try < 1:
        # Without at least one attempt there is no response to score below.
        raise ValueError(f"num_try must be >= 1, got {num_try}")

    # Determine the functions involved (prompt modules are imported lazily)
    if funcs == 'reverse_swap_repeat_cut':
        run_funcs = [pcfg_reverse, pcfg_swap, pcfg_repeat, pcfg_cut]
        from arrayt.prompts_reverse_swap_repeat_cut import (
            verify_prompt_fixed,
            fwd_prompt_header,
            back_prompt_header,
            flip_prompt_header,
        )
    elif funcs == 'reverse_swap_shift':
        run_funcs = [pcfg_reverse, pcfg_swap, pcfg_shift_left, pcfg_shift_right]
        from arrayt.prompts_reverse_swap_shift import (
            verify_prompt_fixed,
            fwd_prompt_header,
            back_prompt_header,
            flip_prompt_header,
        )
    elif funcs == 'shift_repeat_cut':
        run_funcs = [pcfg_shift_left, pcfg_shift_right, pcfg_repeat, pcfg_cut]
        from arrayt.prompts_shift_repeat_cut import (
            verify_prompt_fixed,
            fwd_prompt_header,
            back_prompt_header,
            flip_prompt_header,
        )
    else:
        raise ValueError(f"Unknown funcs from args: {funcs!r}")
    build_funcs = run_funcs
    vocab = [str(digit) for digit in range(10)]

    # Prompt enders
    plan_ender = "\nPlease solve with the exact same format. Do not repeat the problem."
    verify_ender = "\nPlease verify initial to final steps with the exactly same format. Do not repeat the problem."

    # Set up save path
    os.makedirs("result/arrayt", exist_ok=True)
    timestamp = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
    save_path = f"result/arrayt/{setting}_{num_fns}_{array_size}_{num_shot}_{num_run}_{timestamp}.pkl"

    graph_kwargs = dict(
        num_fns=num_fns,
        required_path_length=len_shortest_path,
        max_repeat=max_repeat,
        build_funcs=build_funcs,
        run_funcs=run_funcs,
    )

    # Run all
    success_all = []
    graph_all = []
    candidates_all = []
    for _ in tqdm(range(num_run)):
        # Sample few-shot examples; the last three tuple items are the
        # forward / backward / flipped example texts.
        fwd_examples, back_examples, flip_examples = [], [], []
        for _ in range(num_shot):
            *_, fwd_example, back_example, flip_example = _sample_graph(
                vocab, array_size, **graph_kwargs
            )
            fwd_examples.append(fwd_example)
            back_examples.append(back_example)
            flip_examples.append(flip_example)

        # Sample current problem
        (
            G,
            G_rev,
            num_computations_fwd,
            num_computations_back,
            node_init,
            node_goal,
            problem,
            problem_flip,
            fns,
            vars_init,
            vars_final,
            _,
            _,
            _,
        ) = _sample_graph(vocab, array_size, **graph_kwargs)

        # direction -> (progress message, prompt header, examples, problem text)
        prompts = {
            "fwd": ("Generating forward...", fwd_prompt_header, fwd_examples, problem),
            "back": ("Generating backward...", back_prompt_header, back_examples, problem),
            "flip": ("Generating flipped...", flip_prompt_header, flip_examples, problem_flip),
        }
        # First call per direction is deterministic (temperature 0); later
        # calls in the same direction sample at args.temperature.
        trial_temperature = {"fwd": 0, "back": 0, "flip": 0}

        # Call LLM forward/backward/flipped, verify
        prev_tries = deque(maxlen=prev_try_maxlen)
        candidates_dir = []
        for _ in range(num_try):
            use_fwd = random.random() < 0.5
            direction = coin_fwd_direction if use_fwd else coin_other_direction
            message, header, examples, problem_text = prompts[direction]
            print(message)
            query = (
                header
                + "\n\n".join(examples)
                + "\n\n***** Current problem:\n"
                + problem_text
                + plan_ender
            )
            response, _ = call_lm(
                query,
                api_key,
                model=model,
                max_tokens=150,
                temperature=trial_temperature[direction],
            )
            trial_temperature[direction] = temperature

            # Extract functions (normalized to forward order)
            str_fns = _parse_candidate(response, direction)
            if str_fns is None:
                print("response not finished")
                continue
            prev_tries.append(str_fns)
            candidates_dir.append([str_fns, direction, "incorrect"])

            # Self check - skip if already checked within the recent window
            if str_fns not in list(prev_tries)[:-1]:
                query = (
                    verify_prompt_fixed
                    + "\n***** Current problem:\n"
                    + problem
                    + f"\nFunctions: {'[' + ', '.join(str_fns) + ']'}"
                    + verify_ender
                )
                response_verify, _ = call_lm(
                    query,
                    api_key,
                    model=model,
                    max_tokens=1024,
                    temperature=0,
                )
                if "Correct" in response_verify:
                    candidates_dir[-1][-1] = "correct"
                    break

            # Stop once the whole recent window agrees on one candidate
            if len(prev_tries) == prev_try_maxlen and len(set(prev_tries)) == 1:
                break

        print(f"Attempts: {prev_tries}")

        # Check with ground truth - use last raw response
        response_fns = str_to_fns(response, flip=direction == "flip")
        if direction in ("back", "flip"):
            response_fns = response_fns[::-1]
        state = vars_init
        for fn in response_fns:
            state = fn(state)
        vars_response = state
        success = vars_response == vars_final
        print(f"RESPONSE VARS: {vars_response}")
        print(f"TRUE VARS: {vars_final}")
        print(f"SUCCESS: {success}")

        # Save
        success_all.append(success)
        candidates_all.append(candidates_dir)
        graph_all.append(
            {
                "G": G,
                "G_rev": G_rev,
                "node_init": node_init,
                "node_goal": node_goal,
                "fns": fns,
                "vars_init": vars_init,
                "vars_final": vars_final,
                'num_computations_fwd': num_computations_fwd,
                'num_computations_back': num_computations_back,
            }
        )

    # Save all
    with open(save_path, "wb") as f:
        pickle.dump(
            {
                "successes": success_all,
                "candidates": candidates_all,
                "graphs": graph_all,
            },
            f,
        )
    print(f"saved files to: {save_path}")
    print('Success rate:', np.mean(success_all))


if __name__ == "__main__":
    # Command-line entry point: parse options, validate the API key, echo the
    # configuration, then hand off to run().
    import argparse

    parser = argparse.ArgumentParser()
    # Baseline choice
    parser.add_argument(
        "--mode",
        type=str,
        default='fwd',
        choices=['fwd', 'back', 'flip', 'fwd-back', 'fwd-flip'],
    )
    # Array structures
    parser.add_argument("--num_fns", type=int, default=3)
    parser.add_argument(
        "--funcs",
        type=str,
        default='reverse_swap_repeat_cut',
        choices=['reverse_swap_repeat_cut', 'reverse_swap_shift', 'shift_repeat_cut'],
    )
    parser.add_argument("--array_size", type=int, default=4)
    # Your api key
    parser.add_argument("--api_key", type=str, default="API_KEY")
    # Fixed
    parser.add_argument("--num_run", type=int, default=1)
    parser.add_argument("--model", type=str, default="gpt-4o-2024-05-13")
    parser.add_argument("--temperature", type=float, default=0.5)
    parser.add_argument("--num_shot", type=int, default=3)
    parser.add_argument("--num_try", type=int, default=6)
    parser.add_argument("--prev_try_maxlen", type=int, default=4)
    parser.add_argument("--max_repeat", type=int, default=1)
    args = parser.parse_args()

    # Refuse to run with the placeholder key. Raising a bare string is a
    # TypeError in Python 3, so report via argparse (prints usage, exits 2).
    if args.api_key == "API_KEY":
        parser.error("Please set the API key!")

    # Print all arguments
    print("Arguments:")
    for arg, value in vars(args).items():
        print(f"{arg}: {value}")

    # Run
    print("Running...")
    run(args)