import copy
from tqdm import trange
import wandb
import logging
import traceback
from pathlib import Path

from python_utils.string_utils import remove_docstrings
from typing import Callable

from llm_utils.textgen_api.textgen_api import TextGenApi

from tp_lodge.motion_planning.local_motion_validator import PDDLDomain, PDDLProblem
from tp_lodge.motion_planning.motion_validator import MotionValidator
from tp_lodge.task_planning.pddl_planner.hi_planner.nl_action_node import NLActionNode
from tp_lodge.task_planning.pddl_planner.hi_planner.out_of_retries_error import OutOfRetriesException
from tp_lodge.task_planning.pddl_planner.hi_planner.shared_action_node_storage import SharedActionNodeStorage
import argparse
import json
from pathlib import Path

from dotenv import load_dotenv
from llm_utils.textgen_api.textgen_api import TextGenApi, logger
from state_estimation import PredicateGrounder

from demos.ipc.src.household.household_motion_validator import HouseholdMotionValidator
from demos.ipc.src.household.se_household_motion_validator import SEHouseholdMotionValidator
from demos.ipc.src.logistics.logistics_motion_validator import LogisticsMotionValidator
from demos.ipc.src.planning_benchmark_sample_generator import BenchmarkSampleGenerator, PlanningBenchmarkSampleGenerator
from tp_lodge.utils.pddl_utils import filter_formula_by_predicates
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from tp_lodge.task_planning.pddl_planner.hi_planner.shared_action_node_storage import SharedActionNodeStorage
import logging

# Console logging for this script: INFO and above, message text only (no level/logger prefix).
logging.basicConfig(format="%(message)s", level=logging.INFO)


def _record_failure(out_dir: Path, error: Exception) -> None:
    """Print ``error`` plus the active traceback and save the same text to ``out_dir / "failed"``.

    Must be called from inside an ``except`` block so ``traceback.format_exc``
    sees the exception currently being handled.
    """
    message = "%s\n%s" % (str(error), traceback.format_exc())
    print(message)
    (out_dir / "failed").write_text(message)


def run_single(
    motion_validator: MotionValidator,
    textgen_api: TextGenApi,
    storage: SharedActionNodeStorage,
    instruction: str,
    domain: PDDLDomain,
    problem_skeleton: PDDLProblem,
    out_dir: Path,
    clear_backup: bool = False,
):
    """Plan a single task with a fresh ``NLActionNode`` and persist its artifacts.

    Resets the motion validator (``seed=0``), the shared action-node storage and
    the text-generation usage counters, then runs the planner. Whatever the
    outcome, ``storage.json`` and ``textgen-api-usage.json`` are written to
    ``out_dir``. If planning fails, the error message and traceback are printed
    and written to ``out_dir / "failed"``.

    Args:
        motion_validator: Validator the planner checks actions against.
        textgen_api: LLM client; its usage counters are reset before planning.
        storage: Shared action-node storage, reset to plan into ``out_dir``.
        instruction: Natural-language task instruction.
        domain: Domain the planner starts from.
        problem_skeleton: Problem the planner starts from.
        out_dir: Task output directory (created if missing; parent must exist).
        clear_backup: Forwarded to ``storage.reset``.

    Returns:
        ``planner.last_domain`` after planning, also written to
        ``out_dir / "domain.json"``. This happens even when the planner ran out
        of retries.

    Raises:
        Exception: Any planner error other than ``OutOfRetriesException`` is
            re-raised after being recorded. ``domain.json`` is not written in
            that case.
    """
    motion_validator.init_state(seed=0)
    out_dir.mkdir(exist_ok=True)
    storage.reset(root_plan_dir=out_dir, clear_backup=clear_backup)
    textgen_api.usage.reset()

    planner = NLActionNode(
        storage=storage,
        motion_validator=motion_validator,
        textgen_api=textgen_api,
        out_dir=out_dir,
        domain=domain,
        problem_skeleton=problem_skeleton,
        instruction=instruction,
    )
    planner.delete_env_hashes()

    try:
        # plan()'s return values are not needed: the domain is taken from
        # planner.last_domain below so it is also available after a failure.
        planner.plan(stdout_level=logging.DEBUG)
    except OutOfRetriesException as e:
        # Expected failure mode: record it, then still save the domain below.
        _record_failure(out_dir, e)
    except Exception as e:
        # Unexpected failure: record it and abort; bare `raise` keeps the
        # original traceback unchanged.
        _record_failure(out_dir, e)
        raise
    finally:
        # save metrics (also when planning failed)
        (out_dir / "storage.json").write_text(storage.to_dumps())
        (out_dir / "textgen-api-usage.json").write_text(textgen_api.usage.to_dumps())

    last_domain = planner.last_domain
    (out_dir / "domain.json").write_text(last_domain.dumps())
    return last_domain


def run_sample(
    textgen_api: TextGenApi,
    get_motion_validator: Callable[[Path], MotionValidator],
    domain_skeleton: PDDLDomain,
    storage: SharedActionNodeStorage,
    generator: BenchmarkSampleGenerator,
    s_out_dir: Path,
    clear: bool,
    *,
    learn_one_domain: bool = False,
):
    """Plan every task of ``generator`` for one benchmark sample.

    Task ``i`` is written to ``s_out_dir / f"task-{i}"``. A task whose
    ``domain.json`` already exists is treated as finished, and its domain is
    loaded instead of re-planned. A task directory without ``domain.json`` is
    left over from a run that aborted with an unexpected error, so that task is
    planned again.

    Args:
        textgen_api: LLM client passed to ``run_single``.
        get_motion_validator: Factory returning a fresh motion validator for a
            task output directory.
        domain_skeleton: Domain the first task (or every task, without
            ``learn_one_domain``) starts from.
        storage: Shared action-node storage passed to ``run_single``.
        generator: Yields ``(instruction, env_state, problem_skeleton)`` per task index.
        s_out_dir: Output directory of this sample.
        clear: Forwarded to ``run_single`` as ``clear_backup``.
        learn_one_domain: If True, carry one shared domain across tasks.
            - Before each task after the first, the shared domain is reduced
              via ``only_verified_operators()``.
            - Its predicates are reduced to those with ``pred_type == "other"``
              or ``predefined``.
            - The previous task's ``generated-problem.json`` initial state
              (restricted to the remaining non-predefined predicates) seeds the
              new problem.
            - After each task, the task's operators are merged back into the
              shared domain.
    """
    shared_domain = domain_skeleton

    for i in trange(len(generator), desc="LODGE Planning tasks"):
        t_out_dir = s_out_dir / f"task-{i}"
        domain_file = t_out_dir / "domain.json"
        learned_init_state = None
        if learn_one_domain and i > 0:
            shared_domain = shared_domain.only_verified_operators()
            shared_domain = shared_domain.copy_with(
                predicates=[p for p in shared_domain.predicates if p.pred_type == "other" or p.predefined]
            )

            last_problem_file = s_out_dir / f"task-{i-1}" / "generated-problem.json"
            if last_problem_file.is_file():
                last_problem = PDDLProblem.loads(last_problem_file.read_text())

                learned_init_state = filter_formula_by_predicates(
                    last_problem.initial_state,
                    known_predicates=[
                        # state preds change for every new problem
                        p.name
                        for p in shared_domain.predicates
                        if not p.predefined and p.pred_type == "other"
                    ],
                )

        # run_single writes domain.json whenever it returns normally, so its
        # presence (not the directory's) marks a finished task.
        if not domain_file.is_file():
            if wandb.run is None:
                wandb.init()

            instruction, env_state, problem_skeleton = generator.generate(i)
            logger.info("Planning task %d: %s..." % (i, instruction))

            if learned_init_state is not None:
                problem_skeleton = problem_skeleton.copy_with(initial_state=learned_init_state)

            t_out_dir.mkdir(exist_ok=True)
            # A failure marker from an aborted earlier attempt must not survive
            # a successful re-run.
            stale_failure = t_out_dir / "failed"
            if stale_failure.is_file():
                stale_failure.unlink()

            motion_validator = get_motion_validator(t_out_dir)
            assert hasattr(motion_validator, "env")
            assert hasattr(motion_validator.env, "_init_env_state")
            motion_validator.env._init_env_state = copy.deepcopy(env_state)
            motion_validator.env._set_state(new_state=env_state)

            problem_skeleton = motion_validator.inject_init_predicates(domain=shared_domain, problem=problem_skeleton)

            last_domain = run_single(
                motion_validator=motion_validator,
                textgen_api=textgen_api,
                storage=storage,
                instruction=instruction,
                domain=copy.deepcopy(shared_domain),
                problem_skeleton=problem_skeleton,
                out_dir=t_out_dir,
                clear_backup=clear,
            )
        else:
            last_domain = PDDLDomain.loads(domain_file.read_text())

        if learn_one_domain:
            updated_ops = {}
            for op in last_domain.operators:
                try:
                    shared_op = shared_domain.get_operator_by_id(op.id)
                except KeyError:
                    # operator not in the shared domain yet
                    updated_ops[op.id] = op
                    continue
                if op.verified or not shared_op.verified:
                    # use new operator
                    updated_ops[op.id] = op
                else:
                    # keep old (verified) operator over an unverified new one
                    updated_ops[shared_op.id] = shared_op
            # keep shared operators the task did not touch
            for shared_op in shared_domain.operators:
                if shared_op.id not in updated_ops:
                    updated_ops[shared_op.id] = shared_op

            shared_domain = shared_domain.copy_with(
                operators=list(updated_ops.values()), predicates=last_domain.predicates
            )


def main(args):
    """Run the LODGE planning benchmark for one domain and one LLM.

    Loads the domain knowledge, function stubs and domain skeleton from
    ``data/<domain>``, builds the shared action-node storage, and plans all
    generator tasks for samples 1-3 via ``run_sample``. Results go to
    ``results/<domain>/<llm_id>/hi-tamp/<suffix up to "/sample">``, or to
    ``out`` when ``args.tmp`` is set.

    Args:
        args: Parsed command-line namespace (see the argument parser under
            ``if __name__ == "__main__"``).

    Raises:
        ValueError: The domain is unsupported, the LLM spec does not resolve to
            exactly one connection, or incompatible flags are combined.
        RuntimeError: Household tasks do not all share the same object set.
    """
    load_dotenv()
    if args.domain not in ("logistics", "household"):
        raise ValueError("Unknown domain: %s" % args.domain)

    root_dir = Path(__file__).parent.parent
    data_dir = root_dir / "data" / args.domain
    textgen_api = TextGenApi.default(connection=args.llm)
    connections = textgen_api.connections.connections
    if len(connections) != 1:
        raise ValueError("Expected exactly one LLM connection, got %d" % len(connections))
    llm_id = connections[0].model_dir

    if args.tmp:
        out_dir = root_dir / "out"
    else:
        # Everything from "/sample" onward is dropped from the suffix; the
        # per-sample directories are created below out_dir.
        out_dir = root_dir / "results" / args.domain / llm_id / "hi-tamp" / args.suffix.split("/sample")[0]
    out_dir.mkdir(exist_ok=True, parents=True)

    domain_knowledge = (data_dir / "domain_knowledge.txt").read_text()
    function_stubs = (data_dir / "function_stubs.py").read_text()
    if args.add_unnecessary_skills:
        # Add unnecessary skills to the function stubs
        unnecessary_skills = (data_dir / "unnecessary_skills.py").read_text()
        function_stubs += "\n\n" + unnecessary_skills
    if args.no_docstrings:
        function_stubs = remove_docstrings(function_stubs)

    def get_motion_validator(out_dir: Path) -> MotionValidator:
        """Create a fresh motion validator for one task; ``out_dir`` is the task directory."""
        if not args.learn_predicates:
            if args.domain == "logistics":
                return LogisticsMotionValidator(hide_env_feedback=args.hide_env_feedback)
            elif args.domain == "household":
                return HouseholdMotionValidator(hide_env_feedback=args.hide_env_feedback)
            else:
                raise ValueError("Unknown domain: %s" % args.domain)
        else:
            if args.domain == "logistics":
                raise NotImplementedError("Predicate learning is not implemented for logistics domain yet.")
            elif args.domain == "household":
                code_api_file = root_dir / "src/household/household_predicates_object_api.py"
                # NOTE(review): hide_env_feedback is not forwarded here, unlike the
                # branches above -- confirm whether that is intended.
                return SEHouseholdMotionValidator(
                    grounder=PredicateGrounder(
                        code_api_file=code_api_file,
                        out_dir=out_dir,
                        textgen_api=textgen_api,
                        domain_knowledge=domain_knowledge,
                    )
                )
            else:
                raise ValueError("Unknown domain: %s" % args.domain)

    domain_skeleton = PDDLDomain.from_json(json.loads((data_dir / "domain_skeleton.json").read_text()))

    if args.learn_predicates:
        if args.domain != "household":
            raise ValueError("--learn-predicates is only supported for the household domain")
        # Keep only these types and drop all skeleton predicates.
        domain_skeleton = domain_skeleton.copy_with(
            types={
                k: v
                for k, v in domain_skeleton.types.items()
                if k in ["household_object", "furniture_appliance", "small_receptacle", "small_items", "agent"]
            },
            predicates=[],
        )

    storage = SharedActionNodeStorage(
        domain_knowledge=domain_knowledge,
        function_stubs=function_stubs,
        n_val_retries=args.n_val_retries,
        n_env_retries=args.n_env_retries,
        use_domain_knowledge=args.use_domain_knowledge,
        use_ai_plan=args.use_ai_planner,
        ai_planner_kwargs={
            "alias": "lama-first",
            # "alias": "seq-sat-fdss-2023",
            "search_time_limit": 30,
        },
    )
    for key, value in vars(args).items():
        print(f"{key}: {value}")

    (out_dir / "config.json").write_text(storage.to_dumps())

    generator = PlanningBenchmarkSampleGenerator(args.domain, use_hl_types=True)

    if args.domain == "household":
        # ensure the objects are always the same (needed to copy init state)
        object_sets = set()
        for i in range(len(generator)):
            _instruction, _env_state, problem_skeleton = generator.generate(i)
            object_sets.add(frozenset(problem_skeleton.objects))
        if len(object_sets) != 1:
            raise RuntimeError(
                "Household tasks must all use the same objects (needed to copy init state); "
                "found %d distinct object sets" % len(object_sets)
            )

    if args.show_op_signature:
        from pddl.parser.domain import DomainParser
        from tp_lodge.task_planning.models.pddl.pddl_operator import PDDLOperator
        from tp_lodge.utils.pddl_lib_utils import copy_action_w_args
        from pddl.core import And

        if args.learn_predicates:
            raise ValueError("--show-op-signature cannot be combined with --learn-predicates")

        # Seed the skeleton with the ground-truth operator signatures, with
        # empty preconditions and effects.
        gt_domain = DomainParser()((data_dir / "domain.pddl").read_text())
        ops = [
            PDDLOperator(
                definition=copy_action_w_args(op, precondition=And(), effect=And()),
                description="",
            )
            for op in gt_domain.actions
        ]
        domain_skeleton = domain_skeleton.copy_with(operators=ops)

    for i in range(1, 4):
        print("Running sample %d..." % i)
        s_out_dir = out_dir / f"sample-{i}"
        s_out_dir.mkdir(exist_ok=True)
        run_sample(
            get_motion_validator=get_motion_validator,
            textgen_api=textgen_api,
            storage=storage,
            domain_skeleton=domain_skeleton,
            generator=generator,
            s_out_dir=s_out_dir,
            clear=True,
            learn_one_domain=args.learn_one_domain,
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Construct action models using LLMs.")
    parser.add_argument("--llm", type=str, help="Specify the LLM engine to use.", required=True)
    # main() only supports these two domains; let argparse reject anything else
    # with a proper usage message instead of failing later.
    parser.add_argument(
        "--domain",
        type=str,
        choices=["logistics", "household"],
        help="Specify the task.",
        required=True,
    )
    parser.add_argument("--use-domain-knowledge", action="store_true", help="Use domain knowledge.")
    parser.add_argument("--use-ai-planner", action="store_true", help="Use AI planner.")
    parser.add_argument("--no-docstrings", action="store_true", help="Do not include docstrings in the output.")
    parser.add_argument("--learn-one-domain", action="store_true", help="Learn one domain for all tasks.")
    parser.add_argument(
        "--hide-env-feedback",
        action="store_true",
        help="Pass hide_env_feedback to the logistics/household motion validators.",
    )
    parser.add_argument(
        "--add-unnecessary-skills",
        action="store_true",
        help="Append the domain's unnecessary_skills.py to the function stubs.",
    )
    parser.add_argument(
        "--learn-predicates",
        action="store_true",
        help="Start without skeleton predicates and learn them (household only).",
    )
    parser.add_argument("--n-val-retries", type=int, default=20, help="Number of retries for validation.")
    parser.add_argument("--n-env-retries", type=int, default=10, help="Number of retries for env feedback.")
    parser.add_argument("--show-op-signature", action="store_true", help="Show operator signature in the prompt.")
    parser.add_argument("--tmp", action="store_true", help="Use a temporary directory for output.")
    parser.add_argument(
        "--suffix",
        type=str,
        required=True,
        help="Name of the results sub-directory (anything from '/sample' on is ignored).",
    )
    args = parser.parse_args()
    main(args)
