import argparse
import ast
from functools import partial
from pathlib import Path
from dotenv import load_dotenv
from pddl.core import And, Domain
import json
import re
from demos.scripts.utils import gen_domain_nl, gen_problem_nl

from tp_lodge.motion_planning.dummy_motion_validator import PDDLDomain
from demos.ipc.src.common.function_mapping import map_functions
from tp_lodge.random_walk.random_walk_env import RandomWalkEnv
from tp_lodge.task_planning.pddl_planner.ai_planner.fastdownward_ai_planner import FastDownwardAIPlanner
from tp_lodge.task_planning.pddl_planner.ai_planner.hierarchical_ai_planner import HierarchicalAIPlanner
from tp_lodge.task_planning.pddl_planner.ai_planner.ai_validator import AIValidator
from typing import Callable, Dict, Optional, Tuple
from joblib import Parallel, delayed
from natsort import natsorted
import pandas as pd
from pddl.parser.problem import ProblemParser
from pddl.parser.domain import DomainParser

from llm_utils.openai_api.chat import Chat
from llm_utils.openai_api.system_message import SystemMessage
from llm_utils.openai_api.text_message_content import TextMessageContent
from llm_utils.openai_api.user_message import UserMessage
from llm_utils.textgen_api.textgen_api import TextGenApi
from tqdm import tqdm
from demos.ipc.src.common.utils import get_function_mapping_from_config, get_task_centric_domain
from tp_lodge.task_planning.models.pddl.pddl_problem import PDDLProblem
from python_utils.string_utils import wrap_code, get_markup_from_text

from demos.ipc.src.sadegh.utils import (
    ONE_SHOT_PROBLEM_TRANSLATION_PROMPT,
    PROBLEM_TRANSLATION_SYSTEM_MESSAGE,
    ZERO_SHOT_PROBLEM_TRANSLATION_PROMPT,
)
from tp_lodge.random_walk.random_walk_evaluator import RandomWalkEvaluator, harmonic_mean
from tp_lodge.utils.pddl_lib_utils import copy_action_w_args, copy_domain_w_args, copy_problem_w_args
from tp_lodge.utils.pddl_utils import get_predicate_evaluation, get_valid_predicates
from tp_lodge.task_planning.models.sas.sas_plan import SasPlan
from tp_lodge.task_planning.models.sas.sas_action import SasAction


def sadegh_generate_problem(
    textgen_api: TextGenApi,
    domain_pddl: str,
    domain_nl: str,
    target_problem_nl: str,
    target_problem_template: str,
    *,
    context_problem_nl: Optional[str] = None,
    context_problem_pddl: Optional[str] = None,
    max_retries: Optional[int] = None,
):
    """Translate a natural-language problem description into a PDDL problem with an LLM.

    The one-shot prompt is used when ``context_problem_pddl`` is given; the example's template is derived from it by
    emptying ``init`` and ``goal``. Otherwise the zero-shot prompt is used. The LLM is queried until the newline-joined
    ```pddl blocks of its answer parse as a PDDL problem.

    Args:
        textgen_api: LLM client used to answer the prompt.
        domain_pddl: PDDL of the domain the problem must be expressed in.
        domain_nl: Natural-language description of that domain.
        target_problem_nl: Natural-language description of the problem to translate.
        target_problem_template: PDDL skeleton of the target problem given to the LLM.
        context_problem_nl: Natural-language description of the one-shot example (used with ``context_problem_pddl``).
        context_problem_pddl: PDDL of the one-shot example; ``None`` selects zero-shot prompting.
        max_retries: Number of extra LLM calls allowed after an unparseable answer. ``None`` (default) keeps
            retrying until an answer parses.

    Returns:
        The generated PDDL problem text.

    Raises:
        RuntimeError: if ``max_retries`` is set and no answer parsed within ``max_retries + 1`` calls.
        AssertionError: if the first content item of the LLM response is not a text message.
    """
    kwargs = dict(
        domain_nl=wrap_code(domain_nl, lang="markdown"),
        domain_pddl=wrap_code(domain_pddl, lang="pddl"),
        target_problem_nl=wrap_code(target_problem_nl, lang="markdown"),
        target_problem_template_pddl=wrap_code(target_problem_template, lang="pddl"),
    )
    if context_problem_pddl is not None:
        # The example's template is the example problem with its initial state and goal emptied.
        context_problem_template = str(copy_problem_w_args(ProblemParser()(context_problem_pddl), init=[], goal=And()))
        prompt = ONE_SHOT_PROBLEM_TRANSLATION_PROMPT.format(
            **kwargs,
            context_problem_pddl=wrap_code(context_problem_pddl, lang="pddl"),
            context_problem_nl=wrap_code(context_problem_nl, lang="markdown"),
            context_problem_template_pddl=wrap_code(context_problem_template, lang="pddl"),
        )
    else:
        prompt = ZERO_SHOT_PROBLEM_TRANSLATION_PROMPT.format(**kwargs)

    # Retry in a loop (not by recursion) so repeated parse failures do not grow the call stack.
    attempt = 0
    while True:
        # Fresh chat per attempt: each retry asks the same question from scratch.
        chat = Chat(
            messages=[
                SystemMessage(content=[TextMessageContent(PROBLEM_TRANSLATION_SYSTEM_MESSAGE)]),
                UserMessage(content=[TextMessageContent(prompt)]),
            ]
        )
        response = textgen_api.do_call(chat)
        assert isinstance(response.content[0], TextMessageContent), "Expected a TextMessageContent response"
        gpt_output = response.content[0].text
        generated_pddl = "\n".join(get_markup_from_text(gpt_output, markup=["pddl"]))

        try:
            ProblemParser()(generated_pddl)
        except Exception as e:
            if max_retries is not None and attempt >= max_retries:
                raise RuntimeError(f"Generated PDDL could not be parsed after {attempt + 1} attempts") from e
            attempt += 1
            print(f"Error parsing generated PDDL\n{e}\nRetrying...")
            continue

        return generated_pddl


def _eval_finds_plan(
    plan: SasPlan,
    *,
    target_env: RandomWalkEnv,
    gt_domain: str,
    gt_problem: str,
    gen_function_mapping: Dict[str, Callable[[SasAction], str]],
):
    try:
        gen_plan_mapped = [gen_function_mapping[action.name](action) for action in plan.actions]
        found_gt_plan = target_env.get_plan(gen_plan_mapped)
        if found_gt_plan is not None:
            std_out, success = AIValidator().validate(
                domain=gt_domain,
                problem=gt_problem,
                plan="\n".join(found_gt_plan),
                options="-v",
            )
        else:
            success = False
    except Exception as e:
        print(f"Error validating plan: {e}")
        success = False
    return success


def get_plan_for_lodge(lodge_dir: Path) -> SasPlan:
    """Load the cached plan of ``lodge_dir`` and recursively expand actions that have a cached sub-plan.

    The action at position ``i`` is replaced by the (recursively expanded) plan of
    ``lodge_dir / "sub-actions" / f"{i}-{action.name}"`` when that directory contains ``ai-plan.cache.plan``;
    otherwise the action is kept unchanged.
    """
    top_level = SasPlan.from_string((lodge_dir / "ai-plan.cache.plan").read_text())
    sub_actions_root = lodge_dir / "sub-actions"

    expanded = []
    for position, step in enumerate(top_level.actions):
        step_dir = sub_actions_root / f"{position}-{step.name}"
        has_sub_plan = step_dir.exists() and (step_dir / "ai-plan.cache.plan").exists()
        if has_sub_plan:
            expanded += get_plan_for_lodge(step_dir).actions
        else:
            expanded.append(step)
    return SasPlan(expanded)


def eval_sample(
    problem_file: Path,
    gt_domain: str,
    problems_dir: Path,
    gt_function_mapping_config: dict,
    domain: str,
    variant: str,
    textgen_api: TextGenApi,
    gt_domain_nl: str,
    gt_function_mapping: dict,
    variant_result_dir: Path,
    functions: Dict[str, ast.FunctionDef],
    function_signatures: Dict[str, int],
    lodge_only_verified_operators: bool = True,
    generate_problem_for_lodge: bool = False,
):
    """Evaluate one ground-truth problem against the domain produced by a single experiment variant.

    Steps:
      1. Load the generated domain, problem, plan, function mapping and token usage from ``variant_result_dir``
         according to ``variant``.
      2. For the ``guan``/``lionel``/``sadegh`` variants, translate the ground-truth problem into the generated domain
         with the LLM. The same applies to ``lodge`` when ``generate_problem_for_lodge`` is set or the domain is shared
         across tasks. The result is cached in ``variant_result_dir / "problem" / problem_file.name``.
      3. Search for a plan in the generated domain that, mapped to ground-truth actions, validates on the
         ground-truth problem.
      4. Score the generated domain with random walks on task-centric versions of both domains.

    Args:
        problem_file: Ground-truth problem file; its name must match ``pNN.pddl``.
        gt_domain: Ground-truth domain PDDL text.
        problems_dir: Directory of the ground-truth problems; its parent is the data directory used for
            natural-language problem generation.
        gt_function_mapping_config: Raw ground-truth action -> skill config (each value has a ``"name"`` key).
        domain: Benchmark domain name, forwarded to ``get_task_centric_domain``.
        variant: One of ``lodge``, ``sadegh``, ``cluster``, ``lionel``, ``guan``.
        textgen_api: LLM client used to generate problems.
        gt_domain_nl: Natural-language description of the ground-truth domain.
        gt_function_mapping: Ground-truth action -> skill mapping callables.
        variant_result_dir: Directory holding the variant's outputs for one experiment run.
        functions: Parsed skill stubs by name (only used by the ``cluster`` variant).
        function_signatures: Skill name -> number of positional arguments.
        lodge_only_verified_operators: For ``lodge``, keep only verified operators of the hierarchical domain.
        generate_problem_for_lodge: For ``lodge``, generate the problem with the LLM instead of using the stored one.

    Returns:
        A dict with ``task``, ``planning_success``, token-usage entries, ``t_to_gen_frac`` and ``gen_to_t_frac``.
        Returns ``None`` if the generated problem's initial state cannot be evaluated.

    Raises:
        ValueError: for an unknown variant, a lodge token-usage file without ``call_id``, or a generated mapping that
            references a skill not in ``function_signatures``.
    """
    gt_problem = problem_file.read_text()

    # Problem files are 1-based (p01.pddl); per-task result directories are 0-based (task-0).
    match = re.match(r"p(\d{2})\.pddl", problem_file.name)
    assert match is not None
    idx = int(match.group(1))
    task_name = "task-%d" % (idx - 1)

    # Generated problems are cached here so the LLM is only queried once per problem.
    exp_problem_dir = variant_result_dir / "problem"
    exp_problem_dir.mkdir(exist_ok=True)
    gen_problem_file = exp_problem_dir / problem_file.name
    # Hierarchical domain; only set by the lodge variant and only used by the hierarchical planner below.
    gen_pddl_domain = None
    if variant == "lodge":
        task_dir = variant_result_dir / task_name
        shared_domain = False
        if not task_dir.exists():
            # No per-task directory: a single domain was generated for all tasks.
            shared_domain = True
            task_dir = task_dir.parent
        assert task_dir.exists(), f"Task directory {task_dir} does not exist."

        textgen_file = task_dir / "textgen-api-usage.json"
        if textgen_file.is_file():
            textgen_usage = json.loads(textgen_file.read_text())
            df = pd.DataFrame(textgen_usage["calls"])
            if "call_id" not in df.columns:
                print(f"Call ID not found in textgen API usage for task {task_name}. Skipping.")
                raise ValueError("Call ID not found in textgen API usage.")
            df["call_id"] = df["call_id"].apply(lambda x: x if x is not None else "unknown")
            # Column-wise totals over all LLM calls (a pandas Series).
            tokens = df.sum()
        else:
            tokens = {}

        func_mapping_file = task_dir / "function_mapping.json"
        assert func_mapping_file.exists(), f"Function mapping file not found for task {task_name}. Skipping."

        gen_function_mapping_config = json.loads(func_mapping_file.read_text())
        gen_pddl_domain = PDDLDomain.loads((task_dir / "domain.json").read_text())
        if lodge_only_verified_operators:
            gen_pddl_domain = gen_pddl_domain.only_verified_operators(remove_unused_predicates=False)
        gen_domain = str(DomainParser()((task_dir / "domain.pddl").read_text()))  # non-hierarchical domain

        if not shared_domain:
            gen_pddl_problem = PDDLProblem.loads((task_dir / "generated-problem.json").read_text())
            assert isinstance(gen_pddl_problem, PDDLProblem)
            gen_problem = str(gen_pddl_problem.to_pddl())
            in_context_problem_nl = None
        else:
            # FIXME: hack for furniture bench, where the gt problem aligns with the lodge one. The stored problem
            # serves as the one-shot example and the target problem is generated by the LLM (unless cached).
            gen_pddl_problem = PDDLProblem.loads((task_dir / "problem.json").read_text())
            assert isinstance(gen_pddl_problem, PDDLProblem)
            gen_problem = str(gen_pddl_problem.to_pddl())
            generate_problem_for_lodge = True
            # NOTE(review): the example's PDDL text doubles as its "natural-language" description here.
            in_context_problem_nl = str(gen_problem)
        # Presumably a "failed" marker file means LODGE produced no usable plan for this task.
        if not (task_dir / "failed").is_file():
            gen_plan = get_plan_for_lodge(task_dir)
        else:
            gen_plan = None
    elif variant == "sadegh":
        summary_log_path = variant_result_dir / "summary_logs.json"
        sadegh_exp_log = json.loads(Path(summary_log_path).read_text())

        tokens = {
            v: sadegh_exp_log["summary_metrics"][k]
            for k, v in {"used_prompt_tokens": "input_tokens", "used_completion_tokens": "output_tokens"}.items()
        }

        in_context = sadegh_exp_log["aux"]["problem_candidates_aux"][sadegh_exp_log["aux"]["best_candidate_idx"]]
        gen_domain = (variant_result_dir / "domain.pddl").read_text()

        gen_problem = in_context["gen_problem_pddl"]
        in_context_problem_nl = gen_problem_nl((problems_dir / "p01.pddl").read_text(), data_dir=problems_dir.parent)
        gen_function_mapping_config = gt_function_mapping_config
        # NOTE(review): this discards the in-context problem loaded above, so the problem is generated zero-shot and
        # in_context_problem_nl goes unused. Confirm whether one-shot prompting was intended.
        gen_problem = None
        gen_plan = None
    elif variant == "cluster":
        gen_domain_file = variant_result_dir / "domain.pddl"
        gen_domain = gen_domain_file.read_text()
        in_context_problem_nl = None
        gen_domain_obj = DomainParser()(gen_domain)
        # The cluster variant provides its own problems; only rebind them to the generated domain's name.
        gen_problem = str(
            copy_problem_w_args(ProblemParser()(gen_problem_file.read_text()), domain_name=gen_domain_obj.name)
        )
        gen_plan = None

        # Every operator carries its skill as a precondition predicate whose name prefixes the operator name.
        # Derive the function mapping from that predicate and strip it (and only it) from the preconditions.
        gen_function_mapping_config = {}
        updated_ops = []
        for operator in list(gen_domain_obj.actions):
            preds = (
                operator.precondition.operands if isinstance(operator.precondition, And) else [operator.precondition]
            )
            found = False
            updated_preds = []
            for pred in preds:
                pred_name = pred.name if hasattr(pred, "name") else pred.term.predicate
                if not found and operator.name.startswith(pred_name):
                    assert pred_name in functions
                    fct = functions[pred_name]
                    found = True

                    # Action predicate: map each operator parameter to its position in the predicate (or None).
                    arg_mapping = []
                    for arg in operator.parameters:
                        try:
                            term_idx = list(pred.terms).index(arg)
                            arg_mapping.append(term_idx)
                        except ValueError:
                            arg_mapping.append(None)
                    assert len(fct.args.args) == len([a for a in arg_mapping if a is not None])
                    gen_function_mapping_config[operator.name] = {"name": pred_name, "arg_mapping": arg_mapping}
                    # Keep scanning: preconditions listed after the skill predicate must be preserved.
                    continue
                updated_preds.append(pred)
            assert found
            updated_ops.append(copy_action_w_args(operator, precondition=And(*updated_preds)))

        gen_domain_obj = copy_domain_w_args(gen_domain_obj, actions=updated_ops)
        gen_domain = str(gen_domain_obj)

        tokens = {}
    elif variant == "lionel":
        gen_domain_file = variant_result_dir / "domain.pddl"
        gen_domain = gen_domain_file.read_text()
        gen_problem = None
        in_context_problem_nl = None
        gen_plan = None

        textgen_usage = json.loads((variant_result_dir / "textgen-usage.json").read_text())
        df = pd.DataFrame(textgen_usage["calls"])
        df["call_id"] = df["call_id"].apply(lambda x: x if x is not None else "unknown")
        tokens = df.sum()

        func_mapping_file = variant_result_dir / "function_mapping.json"
        gen_function_mapping_config = json.loads(func_mapping_file.read_text())
    elif variant == "guan":
        gen_domain_file = variant_result_dir / "domain.pddl"
        gen_domain = gen_domain_file.read_text()
        gen_problem = None
        in_context_problem_nl = None
        gen_plan = None

        # Use the first file matching textgen-api-usage*.json.
        textgen_usage = json.loads(list(variant_result_dir.glob("textgen-api-usage*.json"))[0].read_text())
        df = pd.DataFrame(textgen_usage["calls"])
        df["call_id"] = df["call_id"].apply(lambda x: x if x is not None else "unknown")
        tokens = df.sum()

        func_mapping_file = variant_result_dir / "function_mapping.json"
        gen_function_mapping_config = json.loads(func_mapping_file.read_text())
    else:
        raise ValueError(f"Unknown variant: {variant}")

    # These variants have no problem for the target task: translate the ground-truth problem into the generated
    # domain with the LLM, caching the result in gen_problem_file.
    if variant in ["guan", "lionel", "sadegh"] or generate_problem_for_lodge:
        if not gen_problem_file.exists():
            print(f"Generating problem for {problem_file.name} using {variant} variant...")
            # Seed the template's init with the example problem's initial state (filtered by get_valid_predicates).
            if gen_problem is not None:
                init = get_valid_predicates(ProblemParser()(gen_problem).init)
            else:
                init = []
            problem_template = str(copy_problem_w_args(ProblemParser()(gt_problem), init=init, goal=And()))
            gen_problem = sadegh_generate_problem(
                textgen_api=textgen_api,
                domain_pddl=gen_domain,
                domain_nl=gt_domain_nl,
                target_problem_nl=gen_problem_nl(gt_problem, data_dir=problems_dir.parent),
                target_problem_template=problem_template,
                context_problem_nl=in_context_problem_nl,
                context_problem_pddl=gen_problem,
            )

            gen_problem_file.write_text(gen_problem)

        gen_problem = gen_problem_file.read_text()
    assert gen_problem is not None

    # Skip problems whose initial state cannot be evaluated.
    try:
        get_predicate_evaluation(list(ProblemParser()(gen_problem).init))
    except RuntimeError as e:
        print(f"Error evaluating predicates in generated problem {problem_file.name}: {e}. Skipping.")
        return None
    except Exception as e:
        print(f"Unexpected error evaluating predicates in generated problem {problem_file.name}: {e}. Skipping.")
        return None

    gen_domain = gen_domain.replace("(or )", "(and)")  # hotfix: empty disjunctions
    gen_domain_obj = DomainParser()(gen_domain)
    assert isinstance(gen_domain_obj, Domain)

    # Prune the domain to task-relevant actions, i.e. those with a dict entry in the mapping config.
    gen_function_mapping_config = {k: v for k, v in gen_function_mapping_config.items() if isinstance(v, dict)}
    gen_domain_obj = copy_domain_w_args(
        gen_domain_obj, actions=[a for a in gen_domain_obj.actions if a.name in gen_function_mapping_config]
    )

    unknown_mappings = [k for k, v in gen_function_mapping_config.items() if v["name"] not in function_signatures]
    if len(unknown_mappings) > 0:
        raise ValueError(f"Unknown function mappings found in {problem_file.name}: {unknown_mappings}. ")

    # Map generated action names to ground-truth skills.
    gen_function_mapping = get_function_mapping_from_config(
        gen_function_mapping_config, gt_function_signatures=function_signatures
    )

    gen_domain = str(gen_domain_obj)
    gen_domain = gen_domain.replace("(or )", "(and)")  # hotfix
    gt_domain = gt_domain.replace("(or )", "(and)")  # hotfix

    target_env = RandomWalkEnv(domain_pddl=gt_domain, problem_pddl=gt_problem, function_mapping=gt_function_mapping)

    try_validate_plan = partial(
        _eval_finds_plan,
        target_env=target_env,
        gt_domain=gt_domain,
        gt_problem=gt_problem,
        gen_function_mapping=gen_function_mapping,
    )

    # Evaluate plan success: try candidate plans one by one until one validates.
    plan_success = False
    res_plan = None

    # 1) The variant's own plan (lodge only).
    if gen_plan is not None:
        plan_success = try_validate_plan(gen_plan)
        if plan_success:
            res_plan = gen_plan

    # 2) FastDownward (lama-first, 60 s) on the pruned generated domain.
    if not plan_success:
        _, fd_plan_success, ai_plan = FastDownwardAIPlanner(alias="lama-first", search_time_limit=60).plan(
            domain=DomainParser()(gen_domain), problem=ProblemParser()(gen_problem)
        )
        if ai_plan is not None:
            assert fd_plan_success
            plan_success = try_validate_plan(ai_plan)
            if plan_success:
                res_plan = ai_plan

    # 3) Hierarchical planner on the hierarchical domain (lodge only).
    if not plan_success and gen_pddl_domain is not None:
        try:
            ai_h_plan = HierarchicalAIPlanner().plan(domain=gen_pddl_domain, problem=ProblemParser()(gen_problem))
        except Exception as e:
            print(f"Error generating hierarchical plan for task {problem_file.name}: {e}")
            ai_h_plan = None
        if ai_h_plan is not None:
            plan_success = try_validate_plan(ai_h_plan)
            if plan_success:
                res_plan = ai_h_plan

    # 4) Progressively shorter non-empty prefixes of the variant's own plan.
    if not plan_success and gen_plan is not None:
        original_gen_plan = gen_plan
        for i in range(len(original_gen_plan.actions) - 1, 0, -1):
            truncated_plan = SasPlan(original_gen_plan.actions[:i])
            plan_success = try_validate_plan(truncated_plan)
            if plan_success:
                res_plan = truncated_plan
                break

    # Skills used by the successful plan restrict which parts of the domains are evaluated (task-centric evaluation).
    if res_plan is not None:
        # Sorted so the order does not depend on string hash randomization.
        gen_skills = sorted(
            {
                gen_function_mapping[action.name](action)
                for action in res_plan.actions
                if action.name in gen_function_mapping_config
            }
        )
    else:
        gen_skills = None

    gen_domain_for_task = get_task_centric_domain(
        domain=gen_domain,
        config={k: v["name"] for k, v in gen_function_mapping_config.items()},
        task=domain,
        task_name=task_name,
        skills=gen_skills,
    )
    gt_domain_for_task = get_task_centric_domain(
        domain=gt_domain,
        config={k: v["name"] for k, v in gt_function_mapping_config.items()},
        task=domain,
        task_name=task_name,
        skills=gen_skills,
    )

    task_evaluator = RandomWalkEvaluator(gt_domain=gt_domain_for_task, gt_function_mapping=gt_function_mapping)

    try:
        _, t_to_gen_frac, gen_to_t_frac = task_evaluator.evaluate_task(
            gen_domain_for_task,
            gt_problem=gt_problem,
            gen_problem=gen_problem,
            gen_function_mapping=gen_function_mapping,
        )
    except KeyboardInterrupt:
        # Let Ctrl-C abort the run instead of being recorded as a zero score.
        raise
    except BaseException as e:
        # Deliberately broad: SystemExit (and any other non-Exception error) from the evaluator scores the task 0
        # instead of aborting the whole evaluation.
        import traceback

        traceback.print_exc()
        print(f"Error evaluating task {problem_file.name}: {e}")
        t_to_gen_frac = 0
        gen_to_t_frac = 0

    return {
        "task": problem_file.stem,
        "planning_success": plan_success,
        **tokens,
        "t_to_gen_frac": t_to_gen_frac,
        "gen_to_t_frac": gen_to_t_frac,
    }


def _variant_result_dirs(args):
    """Return the variant's result directories that match the ``args.suffix`` glob, in natural sort order.

    Results live under ``<benchmark>/results/<domain>/<model_dir>/<variant dir>/``:
      - ``<benchmark>`` is ``furniturebench`` for ``fb-*`` domains and ``ipc`` otherwise.
      - ``<model_dir>`` comes from the first connection of the LLM client selected by ``args.llm``.

    Only directories are returned.

    Raises:
        ValueError: if ``args.variant`` is not a known variant.
    """
    textgen_api = TextGenApi.default(args.llm)
    benchmark = "furniturebench" if args.domain.startswith("fb-") else "ipc"
    model_results_dir = (
        Path(__file__).parent.parent
        / benchmark
        / "results"
        / args.domain
        / textgen_api.connections.connections[0].model_dir
    )

    # Variant name -> name of its results sub-directory.
    dir_name_by_variant = {
        "sadegh": "sadegh",
        "lodge": "hi-tamp",
        "lionel": "lionel",
        "cluster": "cluster-intersect",
        "guan": "guan",
    }
    if args.variant not in dir_name_by_variant:
        raise ValueError(f"Unknown variant: {args.variant}")

    candidates = natsorted((model_results_dir / dir_name_by_variant[args.variant]).glob(args.suffix))
    return [candidate for candidate in candidates if candidate.is_dir()]


def eval_all(args):
    """Evaluate every ground-truth problem for each matching result directory and write a summary per directory.

    Each directory returned by ``_variant_result_dirs(args)`` gets an ``eval-random-walks.json`` containing:
      - the mean random-walk fractions and the planning success rate;
      - token usage (mean over problems, or sum for ``lodge``);
      - the harmonic mean of the two fractions;
      - the per-problem rows under ``"aux"``.

    Directories that already have this file are skipped unless ``args.force`` is set.

    Args:
        args: Parsed command-line arguments (``domain``, ``llm``, ``variant``, ``suffix``, ``parallel``, ``force``).

    Raises:
        RuntimeError: if no result directory matches ``args.suffix``.
    """
    load_dotenv()
    textgen_api = TextGenApi.default(args.llm)

    # ipc / furniturebench setup
    benchmark_dir = Path(__file__).parent.parent / ("furniturebench" if args.domain.startswith("fb-") else "ipc")
    domain_dir = benchmark_dir / "data" / args.domain

    # Only top-level function definitions are skill stubs; skip imports, docstrings, constants, classes.
    stub_defs = [
        node
        for node in ast.parse((domain_dir / "function_stubs.py").read_text()).body
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
    ]
    functions = {fd.name: fd for fd in stub_defs}
    function_signatures = {fd.name: len(fd.args.args) for fd in stub_defs}

    gt_domain = (domain_dir / "domain.pddl").read_text()
    gt_function_mapping_config = json.loads((domain_dir / "function_mapping.json").read_text())
    gt_function_mapping = get_function_mapping_from_config(gt_function_mapping_config)
    gt_domain_nl = gen_domain_nl(domain_dir, use_docstrings=True)
    problems_dir = domain_dir / "problems"

    # Canonical problem files only: p01.pddl ... p99.pddl.
    problem_files = [p for p in natsorted(problems_dir.glob("p*.pddl")) if re.match(r"p\d{2}\.pddl", p.name)]

    variant_result_dirs = _variant_result_dirs(args)
    if not variant_result_dirs:
        raise RuntimeError(
            "Domains not generated yet. Please run the domain generation script first. "
            f"(no result directories matching {args.suffix!r} for variant {args.variant})"
        )

    def catch_eval_sample(*a, **k):
        # Best effort: a failing problem is reported and dropped instead of aborting the whole run.
        try:
            return eval_sample(*a, **k)
        except Exception as e:
            print(f"Error evaluating sample: {e}")
            return None

    for variant_result_dir in variant_result_dirs:
        print(f"Evaluating variant results in {variant_result_dir}")
        if not variant_result_dir.is_dir():
            raise RuntimeError(
                f"Domains not generated yet. Please run the domain generation script first. ({variant_result_dir})"
            )

        # guan results ship without a function mapping: derive one with the LLM once and cache it.
        if args.variant == "guan":
            gen_domain_file = variant_result_dir / "domain.pddl"
            func_mapping_file = variant_result_dir / "function_mapping.json"
            if not func_mapping_file.exists():
                gen_function_mapping_config = map_functions(
                    textgen_api=textgen_api,
                    domain=DomainParser()(gen_domain_file.read_text()),
                    functions=functions,
                    existing_mapping={},
                )
                func_mapping_file.write_text(json.dumps(gen_function_mapping_config, indent=2))

        eval_walks_file = variant_result_dir / "eval-random-walks.json"

        if not args.force and eval_walks_file.exists():
            print(eval_walks_file)
            print(f"Evaluation already done for {args.domain} ({args.variant}). Skipping.")
            continue

        eval_data = Parallel(n_jobs=-1 if args.parallel else 1)(
            delayed(catch_eval_sample)(
                gt_domain=gt_domain,
                gt_function_mapping_config=gt_function_mapping_config,
                gt_domain_nl=gt_domain_nl,
                problem_file=problem_file,
                problems_dir=problems_dir,
                domain=args.domain,
                variant=args.variant,
                textgen_api=textgen_api,
                gt_function_mapping=gt_function_mapping,
                variant_result_dir=variant_result_dir,
                functions=functions,
                function_signatures=function_signatures,
            )
            for problem_file in tqdm(problem_files, desc="Evaluating problems")
        )
        eval_data = [d for d in eval_data if d is not None]  # filter out None results
        if len(eval_data) == 0:
            print(f"No valid evaluation data for {variant_result_dir}. Skipping.")
            continue

        df = pd.DataFrame(eval_data)
        print(df.loc[:, df.columns != "call_id"])

        t_to_gen_frac = df["t_to_gen_frac"].mean()
        gen_to_t_frac = df["gen_to_t_frac"].mean()
        planning_success = df["planning_success"].mean()
        final_score = harmonic_mean(t_to_gen_frac, gen_to_t_frac)

        # Token usage: mean over problems, except for lodge where it is summed.
        tokens_data = df[[c for c in df.columns if "tokens" in c]]
        tokens_data = (tokens_data.mean() if args.variant != "lodge" else tokens_data.sum()).to_dict()

        eval_random_walks = {
            "t_to_gen_frac": t_to_gen_frac,
            "gen_to_t_frac": gen_to_t_frac,
            "planning_sr": planning_success,
            **tokens_data,
            "harmonic_mean": final_score,
            "aux": json.loads(df.to_json()),
        }
        print(
            f"Harmonic mean: {final_score:.2f} | t_to_gen_frac: {t_to_gen_frac:.2f} | gen_to_t_frac: {gen_to_t_frac:.2f} | planning_sr: {planning_success:.2f}"
        )
        eval_walks_file.write_text(json.dumps(eval_random_walks, indent=2))


if __name__ == "__main__":
    # CLI entry point: evaluate one variant's generated domains for one benchmark domain and LLM.
    parser = argparse.ArgumentParser(
        description="Evaluate generated PDDL domains of one experiment variant against the ground-truth domain."
    )
    parser.add_argument("--domain", type=str, choices=["logistics", "household", "fb-lamp"], required=True)
    parser.add_argument("--llm", type=str, required=True, help="LLM name passed to TextGenApi.default")
    parser.add_argument("--variant", type=str, choices=["sadegh", "lionel", "lodge", "guan", "cluster"], required=True)
    parser.add_argument(
        "--suffix", type=str, required=True, help="Glob pattern selecting the variant's result directories"
    )
    parser.add_argument("--parallel", action="store_true", help="Run evaluation in parallel")
    parser.add_argument(
        "--force",
        action="store_true",
        help="Re-run the evaluation even if eval-random-walks.json already exists (cached generated problems are reused)",
    )
    args = parser.parse_args()

    eval_all(args)
