import argparse
import json
import os
from pathlib import Path
import shutil
import subprocess

from dotenv import load_dotenv
from llm_utils.textgen_api.textgen_api import TextGenApi
from demos.scripts.utils import gen_domain_nl, gen_problem_nl

# Directory two levels above this file, i.e. the parent of the directory that
# contains this script. Presumably the ``demos/`` package root, given the
# ``demos.scripts.utils`` import above; confirm if this script is moved.
demos_root_dir = Path(__file__).parent.parent


def _demo_root_dir(args) -> Path:
    """Return the demo sub-tree that hosts ``args.domain``.

    Domains whose name starts with ``"fb"`` live under ``furniturebench``;
    every other domain lives under ``ipc``.
    """
    if args.domain.startswith("fb"):
        return demos_root_dir / "furniturebench"
    return demos_root_dir / "ipc"


def _stage_sadegh_data(data_dir: Path, sadegh_data_dir: Path, with_docstrings: bool) -> None:
    """Recreate ``sadegh_data_dir`` from the demo domain stored in ``data_dir``.

    Copies ``domain.pddl`` and ``predicate_descriptor.py``, writes a
    natural-language ``domain.nl``, and copies every ``problems/p*.pddl``
    except p01 together with a matching ``<stem>.nl`` description.
    p01 is skipped because it is used as the in-context example.

    Args:
        data_dir: Source domain directory in the demos tree.
        sadegh_data_dir: Destination directory inside the 3rd-party checkout;
            deleted and recreated on every call.
        with_docstrings: Forwarded to ``gen_domain_nl`` to include or omit
            docstrings in the generated domain description.
    """
    # Start from a clean directory so files from a previous run (problems,
    # generated templates) do not leak into this one.
    if sadegh_data_dir.is_dir():
        shutil.rmtree(sadegh_data_dir)
    # parents=True: ``data/domains`` may not exist yet in a fresh checkout.
    sadegh_data_dir.mkdir(parents=True, exist_ok=True)

    for file_name in ("domain.pddl", "predicate_descriptor.py"):
        shutil.copyfile(data_dir / file_name, sadegh_data_dir / file_name)

    (sadegh_data_dir / "domain.nl").write_text(gen_domain_nl(data_dir, with_docstrings))

    for problem_file in (data_dir / "problems").glob("p*.pddl"):
        if "p01" in problem_file.stem:
            continue  # used for in-context
        shutil.copyfile(problem_file, sadegh_data_dir / problem_file.name)
        (sadegh_data_dir / f"{problem_file.stem}.nl").write_text(gen_problem_nl(problem_file.read_text(), data_dir))


def main(args):
    """Run the "sadegh" llm-pddl-planning baseline on one domain for several seeds.

    Stages the domain data into the 3rd-party checkout, generates PDDL
    template files when none exist, and runs the baseline's ``src/main.py``
    once per seed. For each seed it then extracts ``best_gen_domain_pddl``
    from ``<exp_dir>/seed_<seed>/summary_logs.json`` into ``domain.pddl`` in
    the same directory. Seeds whose output directory already exists are
    skipped.

    Args:
        args: Parsed CLI namespace with ``domain`` (domain name), ``llm``
            (key passed to ``TextGenApi.default``) and ``no_docstrings``
            (omit docstrings from the natural-language domain description).

    Raises:
        KeyError: If ``OPENAI_API_KEY`` is not set after loading ``.env``.
        RuntimeError: If template generation or the baseline run exits
            with a non-zero code.
    """
    load_dotenv()

    connection = TextGenApi.default(args.llm).connections.connections[0]

    demo_root = _demo_root_dir(args)
    data_dir = demo_root / "data" / args.domain
    exp_dir = (
        demo_root
        / "results"
        / args.domain
        / connection.model_dir
        / "sadegh"
        / ("no-docstrings" if args.no_docstrings else "docstrings")
    )

    sadegh_root = demos_root_dir.parent / "3rdparty/llm-pddl-planning"
    sadegh_data_dir = sadegh_root / "data/domains" / args.domain
    _stage_sadegh_data(data_dir, sadegh_data_dir, with_docstrings=not args.no_docstrings)

    # The baseline reads its credentials and model from these variable names.
    env = os.environ.copy()
    env.update(
        {
            "OPENAI_KEY": os.environ["OPENAI_API_KEY"],
            "GPT_MODEL": connection.model,
        }
    )

    sadegh_python_path = sadegh_root / ".pixi/envs/default/bin/python"

    for seed in [42, 43, 44]:
        exp_out_dir = exp_dir / f"seed_{seed}"
        if exp_out_dir.exists():
            print(f"Experiment directory {exp_out_dir} already exists. Please remove it before running again.")
            continue

        # Template files are generated once and reused by later seeds.
        # Child output is not captured; it streams straight to the terminal.
        if not any(sadegh_data_dir.glob("*template.pddl")):
            template_result = subprocess.run(
                [sadegh_python_path, sadegh_root / "src/gen_pddl_template_pddl.py", "--domains", args.domain],
                env=env,
            )
            # Explicit check rather than ``assert``, which is stripped under ``python -O``.
            if template_result.returncode != 0:
                raise RuntimeError(f"Failed to generate template files (exit code {template_result.returncode}).")

        # NOTE(review): exp_path is the shared exp_dir. The per-seed sub-directory
        # read below is assumed to be created by main.py; confirm.
        result = subprocess.run(
            [
                sadegh_python_path,
                sadegh_root / "src/main.py",
                # args
                "--cfg.log_prefix=",
                f"--cfg.seed={seed}",
                "--cfg.planning_strategy_args.best_of_n=10",
                "--cfg.problem_translation_args.n_candidates=5",
                "--cfg.problem_translation_args.active=True",
                f"--cfg.target_domain_name={args.domain}",
                f"--cfg.exp_path={exp_dir.as_posix()}",
                "--cfg.eval_best_domain=False",
                "--cfg.wandb_args.entity=claudik",
            ],
            env=env,
            cwd=sadegh_root,
        )
        if result.returncode != 0:
            # stdout/stderr were not captured, so the child's output is
            # already on the terminal; only report the exit code here.
            print("Error running sadegh main.py:")
            print(result.returncode)
            raise RuntimeError(f"Failed to run sadegh main.py (exit code {result.returncode})")

        domain = json.loads((exp_out_dir / "summary_logs.json").read_text())["best_gen_domain_pddl"]
        (exp_out_dir / "domain.pddl").write_text(domain)


if __name__ == "__main__":
    # CLI entry point: parse the options, then hand them to main().
    cli = argparse.ArgumentParser()
    cli.add_argument("--domain", type=str, required=True)
    cli.add_argument("--llm", type=str, required=True)
    cli.add_argument("--no-docstrings", action="store_true", help="Use domain without docstrings")
    parsed_args = cli.parse_args()
    main(parsed_args)
