import os
import json
import logging
import datetime
import pickle
import random
from tqdm import tqdm

from util.llm import call_lm


_DIRECTIONS = ("fwd", "back", "flip")

# --mode value -> tag used in the result file name.
_MODE_SETTINGS = {
    "fwd": "fwd",
    "back": "back",
    "flip": "flip",
    "fwd-back": "fwd_back",
    "fwd-flip": "fwd_flip",
}

# Prompt folders (under plan-bench/prompts/) for the state-summary ("head") step.
# Forward and backward planning share the same head prompt.
_HEAD_DOMAINS = {
    "fwd": "blocksworld_fwd_cot_verify_head",
    "back": "blocksworld_fwd_cot_verify_head",
    "flip": "blocksworld_flip_cot_verify_head",
}

# Prompt folders (under plan-bench/prompts/) for the plan-generation step.
_PLAN_DOMAINS = {
    "fwd": "blocksworld_fwd_cot_verify_plan",
    "back": "blocksworld_back_cot_verify_plan",
    "flip": "blocksworld_flip_cot_verify_plan",
}

_GEN_LOG_MESSAGES = {
    "fwd": "Generating forward...",
    "back": "Generating backward...",
    "flip": "Generating flipped...",
}

_VERIFY_ENDER = "\n\nPlease carefully verify whether the plan above is correct or not following the exact same format as the examples and the tips."


def _load_instances(domain, task):
    """Return the ``"instances"`` list from ``plan-bench/prompts/<domain>/<task>.json``."""
    with open(f"plan-bench/prompts/{domain}/{task}.json", "r", encoding="utf-8") as f:
        return json.load(f)["instances"]


def _pick_direction(mode, use_fwd):
    """Return the planning direction for one head attempt under a validated ``mode``.

    Single-direction modes always return that direction. Mixed modes
    (``fwd-back`` / ``fwd-flip``) return ``"fwd"`` when ``use_fwd`` is true and
    the other direction otherwise.
    """
    if mode in _DIRECTIONS:
        return mode
    if use_fwd:
        return "fwd"
    return "back" if mode == "fwd-back" else "flip"


def _parse_states(gen_response, direction):
    """Extract ``(plan_init, plan_goal, verify_init, verify_goal)`` from a head response.

    The plan states feed the plan prompt. The verify states describe the
    original (un-flipped) problem and feed the verification prompt.

    Raises:
        IndexError: if an expected marker is missing from the response.
        AttributeError: if ``gen_response`` is not a string (e.g. ``None``).
    """
    if direction == "flip":
        init_split = "new init state: "
        goal_split = "new goal state: "
    else:
        # Matches "init state: " as well as "init state (each clause is a stack): ";
        # some LMs cannot follow the "(each clause is a stack)" format.
        init_split = ": "
        goal_split = "goal: "
    plan_init = gen_response.split(init_split, 1)[1].split("\n", 1)[0]
    plan_goal = gen_response.split(goal_split)[1].split("\n", 1)[0]
    if direction == "flip":
        verify_init = gen_response.split("stack): ")[1].split("\n", 1)[0]
        verify_goal = gen_response.split("goal: ")[1].split("\n", 1)[0]
    else:
        verify_init = plan_init
        verify_goal = plan_goal
    return plan_init, plan_goal, verify_init, verify_goal


def _flip_step(step):
    """Invert one action of a flipped-problem plan (pick up <-> put down, stack <-> unstack)."""
    if "pick up" in step:
        return step.replace("pick up", "put down")
    if "put down" in step:
        return step.replace("put down", "pick up")
    if "unstack" in step:
        return step.replace("unstack", "stack").replace("from ", "")
    if "stack" in step:
        # Match the standalone word so block names containing "on" are untouched.
        return step.replace("stack", "unstack").replace(" on ", " from on ")
    return step


def _extract_plan(plan_response, direction):
    """Return the plan steps in ``plan_response`` as newline-joined forward-order text.

    Blank lines are dropped. Lines starting with "(", "[" or "=" are dropped
    too, so that only the steps remain and not the intermediate states.
    ``back`` plans are reversed. ``flip`` plans are reversed and each action
    is inverted.
    """
    steps = [
        line
        for line in plan_response.split("\n")
        # Check for blank lines first: indexing line[0] on "" would raise.
        if line.strip() and line[0] not in "([="
    ]
    if direction == "flip":
        steps = [_flip_step(step) for step in reversed(steps)]
    elif direction == "back":
        steps = steps[::-1]
    return "\n".join(steps)


def run(args):
    """Generate and self-verify Blocksworld plans, then pickle the transcripts.

    For each of the first ``args.num_run`` instances, up to ``args.num_try``
    attempts are made. Each attempt:

    1. Asks the LM for the init/goal states ("head" prompt) in the chosen
       direction. It re-asks until the response can be parsed.
    2. Asks the LM for a plan from those states ("plan" prompt).
    3. Converts the plan to forward order: reversed for ``back``, reversed and
       action-inverted for ``flip``.
    4. Asks the LM to verify the plan. The attempts stop as soon as the
       verification response does not contain "incorrect".

    For each instance, the first head call and the first plan call in each
    direction use temperature 0. Later calls use ``args.temperature``.
    Verification always uses temperature 0.

    Args:
        args: namespace providing ``mode`` (``fwd``, ``back``, ``flip``,
            ``fwd-back`` or ``fwd-flip``), ``num_run``, ``num_try``,
            ``model``, ``api_key`` and ``temperature``.

    Raises:
        ValueError: if ``args.mode`` is not one of the known modes.
    """
    # Config
    num_run = args.num_run
    num_try = args.num_try
    model = args.model
    api_key = args.api_key
    temperature = args.temperature
    task = "task_1_plan_generation"

    # Determine mode (raising a plain string is itself a TypeError in Python 3).
    if args.mode not in _MODE_SETTINGS:
        raise ValueError(f"Unknown mode from args: {args.mode!r}")
    setting = _MODE_SETTINGS[args.mode]

    # Load prompts
    with open("plan-bench/prompts/verify_header.txt", "r", encoding="utf-8") as f:
        verify_header = f.read()
    head_problems = {d: _load_instances(_HEAD_DOMAINS[d], task) for d in _DIRECTIONS}
    plan_problems = {d: _load_instances(_PLAN_DOMAINS[d], task) for d in _DIRECTIONS}

    # Set up save path
    os.makedirs("result/blocksworld", exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    save_path = f"result/blocksworld/{setting}_{task}_{timestamp}.pkl"

    # Run all
    results = []
    for ind in tqdm(range(num_run)):
        head_problem = {d: head_problems[d][ind] for d in _DIRECTIONS}
        plan_problem = {d: plan_problems[d][ind] for d in _DIRECTIONS}
        instance_id = head_problem["fwd"]["instance_id"]

        # Per-direction temperatures: deterministic first, then args.temperature.
        head_T = dict.fromkeys(_DIRECTIONS, 0)
        plan_T = dict.fromkeys(_DIRECTIONS, 0)

        gen_responses = []
        plan_responses = []
        plan_queries = []
        verify_responses = []
        verify_queries = []
        extracted_plans = []
        checks = []
        directions = []
        for _attempt in range(num_try):
            # Summarize states. Every head response, parsed or not, is recorded.
            # NOTE(review): this loop is unbounded; in a single-direction mode
            # with temperature 0, a response that never parses retries forever.
            while True:
                direction = _pick_direction(args.mode, random.random() < 0.5)
                logging.info(_GEN_LOG_MESSAGES[direction])
                gen_response, _ = call_lm(
                    head_problem[direction]["query"],
                    api_key,
                    model=model,
                    max_tokens=1024,
                    temperature=head_T[direction],
                    stop=["END"],
                )
                head_T[direction] = temperature
                gen_responses.append(gen_response)
                directions.append(direction)
                try:
                    plan_init, plan_goal, verify_init, verify_goal = _parse_states(
                        gen_response, direction
                    )
                except (AttributeError, IndexError):
                    logging.info("No new init and goal states found. Try again.")
                else:
                    break

            # Plan steps: keep the few-shot prefix, then append the current problem.
            query_plan = (
                plan_problem[direction]["query"].rsplit("[STATEMENT]", 1)[0]
                + "[STATEMENT]\n"
                + f"init state (each clause is a stack): {plan_init}\ngoal: {plan_goal}\n\n"
            )
            if direction == "back":
                query_plan += "Please follow the format and generate your reversed plan for the current problem. Start with [REVERSED PLAN]"
            else:
                query_plan += "Please follow the format and generate your plan for the current problem. Start with [PLAN]"
            plan_response, _ = call_lm(
                query_plan,
                api_key,
                model=model,
                max_tokens=1024,
                temperature=plan_T[direction],
                stop=["END"],
            )
            plan_T[direction] = temperature
            plan_responses.append(plan_response)
            plan_queries.append(query_plan)

            # Extract plan steps in forward order; fall back to "" if unparseable.
            try:
                extracted_plan = _extract_plan(plan_response, direction)
            except (AttributeError, TypeError):
                logging.info(
                    f"Cannot parse plan steps from the response: {plan_response}. Use empty string instead."
                )
                extracted_plan = ""
            extracted_plans.append(extracted_plan)

            # Self-verify against the original (un-flipped) problem statement.
            verify_query = (
                verify_header
                + "\n\n[PLAN]\n"
                + "init state (each clause is a stack): "
                + verify_init
                + "\ngoal: "
                + verify_goal
                + "\n\nMy plan is as follows:\n"
                + extracted_plan
                + _VERIFY_ENDER
            )
            verify_response, _ = call_lm(
                verify_query,
                api_key,
                model=model,
                max_tokens=1024,
                temperature=0,
                stop=["END", "[PLAN]"],
            )
            check = "incorrect" not in verify_response
            verify_responses.append(verify_response)
            verify_queries.append(verify_query)
            checks.append(check)
            if check:
                logging.info("======== Success! The generated plan is correct! ========")
                break

        # save
        results.append(
            {
                "task": task,
                "instance_id": instance_id,
                "gen_responses": gen_responses,
                "plan_responses": plan_responses,
                "plan_queries": plan_queries,
                "verify_responses": verify_responses,
                "verify_queries": verify_queries,
                "extracted_plans": extracted_plans,
                "checks": checks,
                "directions": directions,
            }
        )

    # save
    with open(save_path, "wb") as f:
        pickle.dump(results, f)
    print("Saved to", save_path)
    print(f"Please run evaluation with 'python plan-bench/evaluate.py --mode {args.mode} --result {save_path}'!")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    # Baseline choice: which planning direction(s) run() may use.
    parser.add_argument(
        "--mode",
        type=str,
        default="fwd",
        choices=["fwd", "back", "flip", "fwd-back", "fwd-flip"],
    )
    # Your api key
    parser.add_argument("--api_key", type=str, default="API_KEY")
    # Fixed
    parser.add_argument("--num_run", type=int, default=1)
    parser.add_argument("--model", type=str, default="gpt-4o-2024-05-13")
    parser.add_argument("--temperature", type=float, default=0.5)
    parser.add_argument("--num_try", type=int, default=6)
    args = parser.parse_args()

    # Refuse to run with the placeholder key. Raising a plain string is a
    # TypeError in Python 3, so report the problem through argparse instead.
    if args.api_key == "API_KEY":
        parser.error("Please set the API key!")

    # Print all arguments, masking the key so the secret is not echoed to stdout.
    print("Arguments:")
    for name, value in vars(args).items():
        if name == "api_key":
            value = "***"
        print(f"{name}: {value}")

    # Run
    print("Running...")
    run(args)